Merge branch 'main' into patch-1

This commit is contained in:
Kristoffer Dalby 2021-11-15 23:00:45 +00:00 committed by GitHub
commit bd7b5e97cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 2981 additions and 1869 deletions

View file

@ -18,12 +18,11 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: "1.16.3" go-version: "1.17.3"
- name: Install dependencies - name: Install dependencies
run: | run: |
go version go version
go install golang.org/x/lint/golint@latest
sudo apt update sudo apt update
sudo apt install -y make sudo apt install -y make

View file

@ -1,20 +1,37 @@
---
name: CI name: CI
on: [push, pull_request] on: [push, pull_request]
jobs: jobs:
# The "build" workflow golangci-lint:
lint:
# The type of runner that the job will run on
runs-on: ubuntu-latest runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps: steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2 - uses: actions/checkout@v2
# Install and run golangci-lint as a separate step, it's much faster this - name: golangci-lint
# way because this action has caching. It'll get run again in `make lint` uses: golangci/golangci-lint-action@v2
# below, but it's still much faster in the end than installing with:
# golangci-lint manually in the `Run lint` step. version: latest
- uses: golangci/golangci-lint-action@v2
prettier-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Prettify code
uses: creyD/prettier_action@v4.0
with:
prettier_options: >-
--check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
only_changed: false
dry: true
proto-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: bufbuild/buf-setup-action@v0.7.0
- uses: bufbuild/buf-lint-action@v1
with:
input: "proto"

View file

@ -1,7 +1,53 @@
--- ---
run: run:
timeout: 5m timeout: 10m
issues: issues:
skip-dirs: skip-dirs:
- gen - gen
linters:
enable-all: true
disable:
- exhaustivestruct
- revive
- lll
- interfacer
- scopelint
- maligned
- golint
- gofmt
- gochecknoglobals
- gochecknoinits
- gocognit
- funlen
- exhaustivestruct
- tagliatelle
- godox
- ireturn
# In progress
- gocritic
# We should strive to enable these:
- wrapcheck
- dupl
- makezero
# We might want to enable this, but it might be a lot of work
- cyclop
- nestif
- wsl # might be incompatible with gofumpt
- testpackage
- paralleltest
linters-settings:
varnamelen:
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-names:
- err
- db
- id
- ip
- ok
- c

View file

@ -1,6 +1,14 @@
# Calculate version # Calculate version
version = $(shell ./scripts/version-at-commit.sh) version = $(shell ./scripts/version-at-commit.sh)
rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
# GO_SOURCES = $(wildcard *.go)
# PROTO_SOURCES = $(wildcard **/*.proto)
GO_SOURCES = $(call rwildcard,,*.go)
PROTO_SOURCES = $(call rwildcard,,*.proto)
build: build:
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
@ -19,7 +27,12 @@ coverprofile_html:
go tool cover -html=coverage.out go tool cover -html=coverage.out
lint: lint:
golangci-lint run --fix golangci-lint run --fix --timeout 10m
fmt:
prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}'
golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES)
clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES)
proto-lint: proto-lint:
cd proto/ && buf lint cd proto/ && buf lint

View file

@ -54,7 +54,6 @@ Suggestions/PRs welcomed!
Please have a look at the documentation under [`docs/`](docs/). Please have a look at the documentation under [`docs/`](docs/).
## Disclaimer ## Disclaimer
1. We have nothing to do with Tailscale, or Tailscale Inc. 1. We have nothing to do with Tailscale, or Tailscale Inc.
@ -64,6 +63,23 @@ Please have a look at the documentation under [`docs/`](docs/).
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator). To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
### Code style
To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules:
The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and
formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and
[`gofumpt`](https://github.com/mvdan/gofumpt).
Please configure your editor to run the tools while developing and make sure to
run `make lint` and `make fmt` before committing any code.
The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and
formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html).
The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io).
Check out the `.golangci.yaml` and `Makefile` to see the specific configuration.
### Install development tools ### Install development tools
- Go - Go
@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c
```shell ```shell
make generate make generate
``` ```
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review. **Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
To run the tests: To run the tests:
@ -261,5 +278,3 @@ make build
</td> </td>
</tr> </tr>
</table> </table>

149
acls.go
View file

@ -9,23 +9,30 @@ import (
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
const ( const (
errorEmptyPolicy = Error("empty policy") errEmptyPolicy = Error("empty policy")
errorInvalidAction = Error("invalid action") errInvalidAction = Error("invalid action")
errorInvalidUserSection = Error("invalid user section") errInvalidUserSection = Error("invalid user section")
errorInvalidGroup = Error("invalid group") errInvalidGroup = Error("invalid group")
errorInvalidTag = Error("invalid tag") errInvalidTag = Error("invalid tag")
errorInvalidNamespace = Error("invalid namespace") errInvalidNamespace = Error("invalid namespace")
errorInvalidPortFormat = Error("invalid port format") errInvalidPortFormat = Error("invalid port format")
) )
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules const (
Base10 = 10
BitSize16 = 16
portRangeBegin = 0
portRangeEnd = 65535
expectedTokenItems = 2
)
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicy(path string) error { func (h *Headscale) LoadACLPolicy(path string) error {
policyFile, err := os.Open(path) policyFile, err := os.Open(path)
if err != nil { if err != nil {
@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close() defer policyFile.Close()
var policy ACLPolicy var policy ACLPolicy
b, err := io.ReadAll(policyFile) policyBytes, err := io.ReadAll(policyFile)
if err != nil { if err != nil {
return err return err
} }
ast, err := hujson.Parse(b) ast, err := hujson.Parse(policyBytes)
if err != nil { if err != nil {
return err return err
} }
ast.Standardize() ast.Standardize()
b = ast.Pack() policyBytes = ast.Pack()
err = json.Unmarshal(b, &policy) err = json.Unmarshal(policyBytes, &policy)
if err != nil { if err != nil {
return err return err
} }
if policy.IsZero() { if policy.IsZero() {
return errorEmptyPolicy return errEmptyPolicy
} }
h.aclPolicy = &policy h.aclPolicy = &policy
@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error {
return err return err
} }
h.aclRules = rules h.aclRules = rules
return nil return nil
} }
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs { for index, acl := range h.aclPolicy.ACLs {
if a.Action != "accept" { if acl.Action != "accept" {
return nil, errorInvalidAction return nil, errInvalidAction
} }
r := tailcfg.FilterRule{} filterRule := tailcfg.FilterRule{}
srcIPs := []string{} srcIPs := []string{}
for j, u := range a.Users { for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(u) srcs, err := h.generateACLPolicySrcIP(user)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, User %d", i, j) Msgf("Error parsing ACL %d, User %d", index, innerIndex)
return nil, err return nil, err
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
r.SrcIPs = srcIPs filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports { for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(d) dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j) Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
return nil, err return nil, err
} }
destPorts = append(destPorts, dests...) destPorts = append(destPorts, dests...)
@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u) return h.expandAlias(u)
} }
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) { func (h *Headscale) generateACLPolicyDestPorts(
d string,
) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":") tokens := strings.Split(d, ":")
if len(tokens) < 2 || len(tokens) > 3 { if len(tokens) < expectedTokenItems || len(tokens) > 3 {
return nil, errorInvalidPortFormat return nil, errInvalidPortFormat
} }
var alias string var alias string
@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
// tag:montreal-webserver:80,443 // tag:montreal-webserver:80,443
// tag:api-server:443 // tag:api-server:443
// example-host-1:* // example-host-1:*
if len(tokens) == 2 { if len(tokens) == expectedTokenItems {
alias = tokens[0] alias = tokens[0]
} else { } else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
dests = append(dests, pr) dests = append(dests, pr)
} }
} }
return dests, nil return dests, nil
} }
func (h *Headscale) expandAlias(s string) ([]string, error) { func (h *Headscale) expandAlias(alias string) ([]string, error) {
if s == "*" { if alias == "*" {
return []string{"*"}, nil return []string{"*"}, nil
} }
if strings.HasPrefix(s, "group:") { if strings.HasPrefix(alias, "group:") {
if _, ok := h.aclPolicy.Groups[s]; !ok { if _, ok := h.aclPolicy.Groups[alias]; !ok {
return nil, errorInvalidGroup return nil, errInvalidGroup
} }
ips := []string{} ips := []string{}
for _, n := range h.aclPolicy.Groups[s] { for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n) nodes, err := h.ListMachinesInNamespace(n)
if err != nil { if err != nil {
return nil, errorInvalidNamespace return nil, errInvalidNamespace
} }
for _, node := range nodes { for _, node := range nodes {
ips = append(ips, node.IPAddress) ips = append(ips, node.IPAddress)
} }
} }
return ips, nil return ips, nil
} }
if strings.HasPrefix(s, "tag:") { if strings.HasPrefix(alias, "tag:") {
if _, ok := h.aclPolicy.TagOwners[s]; !ok { if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
return nil, errorInvalidTag return nil, errInvalidTag
} }
// This will have HORRIBLE performance. // This will have HORRIBLE performance.
@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err return nil, err
} }
ips := []string{} ips := []string{}
for _, m := range machines { for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this // FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags { for _, t := range hostinfo.RequestTags {
if s[4:] == t { if alias[4:] == t {
ips = append(ips, m.IPAddress) ips = append(ips, machine.IPAddress)
break break
} }
} }
} }
} }
return ips, nil return ips, nil
} }
n, err := h.GetNamespace(s) n, err := h.GetNamespace(alias)
if err == nil { if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name) nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil { if err != nil {
@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
for _, n := range nodes { for _, n := range nodes {
ips = append(ips, n.IPAddress) ips = append(ips, n.IPAddress)
} }
return ips, nil return ips, nil
} }
if h, ok := h.aclPolicy.Hosts[s]; ok { if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil return []string{h.String()}, nil
} }
ip, err := netaddr.ParseIP(s) ip, err := netaddr.ParseIP(alias)
if err == nil { if err == nil {
return []string{ip.String()}, nil return []string{ip.String()}, nil
} }
cidr, err := netaddr.ParseIPPrefix(s) cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil { if err == nil {
return []string{cidr.String()}, nil return []string{cidr.String()}, nil
} }
return nil, errorInvalidUserSection return nil, errInvalidUserSection
} }
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if s == "*" { if portsStr == "*" {
return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
} }
ports := []tailcfg.PortRange{} ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") { for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(p, "-") rang := strings.Split(portStr, "-")
if len(rang) == 1 { switch len(rang) {
pi, err := strconv.ParseUint(rang[0], 10, 16) case 1:
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ports = append(ports, tailcfg.PortRange{ ports = append(ports, tailcfg.PortRange{
First: uint16(pi), First: uint16(port),
Last: uint16(pi), Last: uint16(port),
}) })
} else if len(rang) == 2 {
start, err := strconv.ParseUint(rang[0], 10, 16) case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
last, err := strconv.ParseUint(rang[1], 10, 16) last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
First: uint16(start), First: uint16(start),
Last: uint16(last), Last: uint16(last),
}) })
} else {
return nil, errorInvalidPortFormat default:
return nil, errInvalidPortFormat
} }
} }
return &ports, nil return &ports, nil
} }

View file

@ -5,54 +5,58 @@ import (
) )
func (s *Suite) TestWrongPath(c *check.C) { func (s *Suite) TestWrongPath(c *check.C) {
err := h.LoadACLPolicy("asdfg") err := app.LoadACLPolicy("asdfg")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestBrokenHuJson(c *check.C) { func (s *Suite) TestBrokenHuJson(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/broken.hujson") err := app.LoadACLPolicy("./tests/acls/broken.hujson")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestInvalidPolicyHuson(c *check.C) { func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/invalid.hujson") err := app.LoadACLPolicy("./tests/acls/invalid.hujson")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(err, check.Equals, errorEmptyPolicy) c.Assert(err, check.Equals, errEmptyPolicy)
} }
func (s *Suite) TestParseHosts(c *check.C) { func (s *Suite) TestParseHosts(c *check.C) {
var hs Hosts var hosts Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`)) err := hosts.UnmarshalJSON(
c.Assert(hs, check.NotNil) []byte(
`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
),
)
c.Assert(hosts, check.NotNil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestParseInvalidCIDR(c *check.C) { func (s *Suite) TestParseInvalidCIDR(c *check.C) {
var hs Hosts var hosts Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`)) err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
c.Assert(hs, check.IsNil) c.Assert(hosts, check.IsNil)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestRuleInvalidGeneration(c *check.C) { func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestBasicRule(c *check.C) { func (s *Suite) TestBasicRule(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := app.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
} }
func (s *Suite) TestPortRange(c *check.C) { func (s *Suite) TestPortRange(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := app.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) {
} }
func (s *Suite) TestPortWildcard(c *check.C) { func (s *Suite) TestPortWildcard(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := app.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) {
} }
func (s *Suite) TestPortNamespace(c *check.C) { func (s *Suite) TestPortNamespace(c *check.C) {
n, err := h.CreateNamespace("testnamespace") namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("testnamespace", "testmachine") _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
ip, _ := h.getAvailableIP() ip, _ := app.getAvailableIP()
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: ip.String(), IPAddress: ip.String(),
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson") err = app.LoadACLPolicy(
"./tests/acls/acl_policy_basic_namespace_as_user.hujson",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := app.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) {
} }
func (s *Suite) TestPortGroup(c *check.C) { func (s *Suite) TestPortGroup(c *check.C) {
n, err := h.CreateNamespace("testnamespace") namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("testnamespace", "testmachine") _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
ip, _ := h.getAvailableIP() ip, _ := app.getAvailableIP()
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: ip.String(), IPAddress: ip.String(),
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := app.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)

View file

@ -8,7 +8,7 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
) )
// ACLPolicy represents a Tailscale ACL Policy // ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct { type ACLPolicy struct {
Groups Groups `json:"Groups"` Groups Groups `json:"Groups"`
Hosts Hosts `json:"Hosts"` Hosts Hosts `json:"Hosts"`
@ -17,61 +17,63 @@ type ACLPolicy struct {
Tests []ACLTest `json:"Tests"` Tests []ACLTest `json:"Tests"`
} }
// ACL is a basic rule for the ACL Policy // ACL is a basic rule for the ACL Policy.
type ACL struct { type ACL struct {
Action string `json:"Action"` Action string `json:"Action"`
Users []string `json:"Users"` Users []string `json:"Users"`
Ports []string `json:"Ports"` Ports []string `json:"Ports"`
} }
// Groups references a series of alias in the ACL rules // Groups references a series of alias in the ACL rules.
type Groups map[string][]string type Groups map[string][]string
// Hosts are alias for IP addresses or subnets // Hosts are alias for IP addresses or subnets.
type Hosts map[string]netaddr.IPPrefix type Hosts map[string]netaddr.IPPrefix
// TagOwners specify what users (namespaces?) are allow to use certain tags // TagOwners specify what users (namespaces?) are allow to use certain tags.
type TagOwners map[string][]string type TagOwners map[string][]string
// ACLTest is not implemented, but should be use to check if a certain rule is allowed // ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct { type ACLTest struct {
User string `json:"User"` User string `json:"User"`
Allow []string `json:"Allow"` Allow []string `json:"Allow"`
Deny []string `json:"Deny,omitempty"` Deny []string `json:"Deny,omitempty"`
} }
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects // UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (h *Hosts) UnmarshalJSON(data []byte) error { func (hosts *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{} newHosts := Hosts{}
hs := make(map[string]string) hostIPPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data) ast, err := hujson.Parse(data)
if err != nil { if err != nil {
return err return err
} }
ast.Standardize() ast.Standardize()
data = ast.Pack() data = ast.Pack()
err = json.Unmarshal(data, &hs) err = json.Unmarshal(data, &hostIPPrefixMap)
if err != nil { if err != nil {
return err return err
} }
for k, v := range hs { for host, prefixStr := range hostIPPrefixMap {
if !strings.Contains(v, "/") { if !strings.Contains(prefixStr, "/") {
v = v + "/32" prefixStr += "/32"
} }
prefix, err := netaddr.ParseIPPrefix(v) prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil { if err != nil {
return err return err
} }
hosts[k] = prefix newHosts[host] = prefix
} }
*h = hosts *hosts = newHosts
return nil return nil
} }
// IsZero is perhaps a bit naive here // IsZero is perhaps a bit naive here.
func (p ACLPolicy) IsZero() bool { func (policy ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true return true
} }
return false return false
} }

297
api.go
View file

@ -10,31 +10,37 @@ import (
"strings" "strings"
"time" "time"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
const reservedResponseHeaderSize = 4
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key // Listens in /key.
func (h *Headscale) KeyHandler(c *gin.Context) { func (h *Headscale) KeyHandler(ctx *gin.Context) {
c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString())) ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
[]byte(h.publicKey.HexString()),
)
} }
// RegisterWebAPI shows a simple message in the browser to point to the CLI // RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register // Listens in /register.
func (h *Headscale) RegisterWebAPI(c *gin.Context) { func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
mKeyStr := c.Query("key") machineKeyStr := ctx.Query("key")
if mKeyStr == "" { if machineKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
</body> </body>
</html> </html>
`, mKeyStr))) `, machineKeyStr)))
} }
// RegistrationHandler handles the actual registration process of a machine // RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id // Endpoint /machine/:id.
func (h *Headscale) RegistrationHandler(c *gin.Context) { func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := c.Param("id") machineKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!") ctx.String(http.StatusInternalServerError, "Sad!")
return return
} }
req := tailcfg.RegisterRequest{} req := tailcfg.RegisterRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!") ctx.String(http.StatusInternalServerError, "Very sad!")
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
m, err := h.GetMachineByMachineKey(mKey.HexString()) machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{ newMachine := Machine{
Expiry: &time.Time{}, Expiry: &time.Time{},
MachineKey: mKey.HexString(), MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
} }
if err := h.db.Create(&newMachine).Error; err != nil { if err := h.db.Create(&newMachine).Error; err != nil {
@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Could not create row") Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return return
} }
m = &newMachine machine = &newMachine
} }
if !m.Registered && req.Auth.AuthKey != "" { if !machine.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, *m) h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
return return
} }
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info(). log.Info().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client requested logout") Msg("Client requested logout")
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m) h.db.Save(&machine)
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
if m.Registered && m.Expiry.UTC().After(now) { if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map // The machine registration is valid, respond with redirect to /map
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map") Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
resp.Login = *m.Namespace.toLogin() resp.Login = *machine.Namespace.toLogin()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
c.String(http.StatusInternalServerError, "") Inc()
ctx.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
c.Data(200, "application/json; charset=utf-8", respBody) Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
// The client has registered before, but has expired // The client has registered before, but has expired
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} }
// When a client connects, it may request a specific expiry time in its // When a client connects, it may request a specific expiry time in its
@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be // into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow // retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry. // completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry machine.RequestedExpiry = &req.Expiry
h.db.Save(&m) h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
c.String(http.StatusInternalServerError, "") Inc()
ctx.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
c.Data(200, "application/json; charset=utf-8", respBody) Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
machine.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
m.NodeKey = wgkey.Key(req.NodeKey).HexString() machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m) h.db.Save(&machine)
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
// The machine registration is new, redirect the client to the registration URL // The machine registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.HexString(),
)
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} }
// save the requested expiry time for retrieval later in the authentication flow // save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry machine.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m) h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) { func (h *Headscale) getMapResponse(
machineKey wgkey.Key,
req tailcfg.MapRequest,
machine *Machine,
) ([]byte, error) {
log.Trace(). log.Trace().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Cannot convert to node") Msg("Cannot convert to node")
return nil, err return nil, err
} }
peers, err := h.getPeers(m) peers, err := h.getPeers(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return nil, err return nil, err
} }
profiles := getMapResponseUserProfiles(*m, peers) profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Failed to convert peers to Tailscale nodes") Msg("Failed to convert peers to Tailscale nodes")
return nil, err return nil, err
} }
dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers) dnsConfig := getMapResponseDNSConfig(
if err != nil { h.cfg.DNSConfig,
log.Error(). h.cfg.BaseDomain,
Str("func", "getMapResponse"). *machine,
Err(err). peers,
Msg("Failed generate the DNSConfig") )
return nil, err
}
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil) srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
respBody, err = encode(resp, &mKey, h.privateKey) respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// declare the incoming size on the first 4 bytes // declare the incoming size on the first 4 bytes
data := make([]byte, 4) data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)
return data, nil return data, nil
} }
func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) { func (h *Headscale) getMapKeepAliveResponse(
resp := tailcfg.MapResponse{ machineKey wgkey.Key,
mapRequest tailcfg.MapRequest,
) ([]byte, error) {
mapResponse := tailcfg.MapResponse{
KeepAlive: true, KeepAlive: true,
} }
var respBody []byte var respBody []byte
var err error var err error
if req.Compress == "zstd" { if mapRequest.Compress == "zstd" {
src, _ := json.Marshal(resp) src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil) srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
respBody, err = encode(resp, &mKey, h.privateKey) respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
data := make([]byte, 4) data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)
return data, nil return data, nil
} }
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
c *gin.Context, ctx *gin.Context,
db *gorm.DB, db *gorm.DB,
idKey wgkey.Key, idKey wgkey.Key,
req tailcfg.RegisterRequest, reqisterRequest tailcfg.RegisterRequest,
m Machine, machine Machine,
) { ) {
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", req.Hostinfo.Hostname). Str("machine", reqisterRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", req.Hostinfo.Hostname) Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey) pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return return
} }
c.Data(401, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return return
} }
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address") Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Failed to find an available IP") Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return return
} }
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msgf("Assigning %s to %s", ip, m.Name) Msgf("Assigning %s to %s", ip, machine.Name)
m.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = pak.NamespaceID machine.NamespaceID = pak.NamespaceID
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
m.Registered = true HexString()
m.RegisterMethod = "authKey" // we update it just in case
db.Save(&m) machine.Registered = true
machine.RegisterMethod = "authKey"
db.Save(&machine)
pak.Used = true pak.Used = true
db.Save(&pak) db.Save(&pak)
@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
c.String(http.StatusInternalServerError, "Extremely sad!") Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
c.Data(200, "application/json; charset=utf-8", respBody) Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

271
app.go
View file

@ -18,20 +18,19 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/philip-bui/grpc-zerolog" "github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
zl "github.com/rs/zerolog" zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/soheilhy/cmux" "github.com/soheilhy/cmux"
ginprometheus "github.com/zsais/go-gin-prometheus" ginprometheus "github.com/zsais/go-gin-prometheus"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -48,7 +47,16 @@ import (
) )
const ( const (
AUTH_PREFIX = "Bearer " AuthPrefix = "Bearer "
Postgres = "postgresql"
Sqlite = "sqlite3"
updateInterval = 5000
HTTPReadTimeout = 30 * time.Second
errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error(
"unknown value for Lets Encrypt challenge type",
)
) )
// Config contains the initial Headscale configuration. // Config contains the initial Headscale configuration.
@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
var dbString string var dbString string
switch cfg.DBtype { switch cfg.DBtype {
case "postgres": case Postgres:
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, dbString = fmt.Sprintf(
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass) "host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
case "sqlite3": cfg.DBhost,
cfg.DBport,
cfg.DBname,
cfg.DBuser,
cfg.DBpass,
)
case Sqlite:
dbString = cfg.DBpath dbString = cfg.DBpath
default: default:
return nil, errors.New("unsupported DB") return nil, errUnsupportedDatabase
} }
h := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
} }
err = h.initDB() err = app.initDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cfg.OIDC.Issuer != "" { if cfg.OIDC.Issuer != "" {
err = h.initOIDC() err = app.initOIDC()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) magicDNSDomains := generateMagicDNSRootDomains(
if err != nil { app.cfg.IPPrefix,
return nil, err )
}
// we might have routes already from Split DNS // we might have routes already from Split DNS
if h.cfg.DNSConfig.Routes == nil { if app.cfg.DNSConfig.Routes == nil {
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
} }
for _, d := range magicDNSDomains { for _, d := range magicDNSDomains {
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
} }
} }
return &h, nil return &app, nil
} }
// Redirect to our TLS url. // Redirect to our TLS url.
@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return return
} }
for _, ns := range namespaces { for _, namespace := range namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name) machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil { if err != nil {
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace") log.Error().
Err(err).
Str("namespace", namespace.Name).
Msg("Error listing machines in namespace")
return return
} }
for _, m := range machines { for _, machine := range machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && if machine.AuthKey != nil && machine.LastSeen != nil &&
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { machine.AuthKey.Ephemeral &&
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database") time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().
Str("machine", machine.Name).
Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error err = h.db.Unscoped().Delete(machine).Error
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database") Msg("🤮 Cannot delete ephemeral machine from the database")
} }
} }
} }
h.setLastStateChangeToNow(ns.Name) h.setLastStateChangeToNow(namespace.Name)
} }
} }
@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{}, req interface{},
info *grpc.UnaryServerInfo, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) { handler grpc.UnaryHandler) (interface{}, error) {
// Check if the request is coming from the on-server client. // Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability // This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client // with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to // It is also neede for grpc-gateway to be able to connect to
// the server // the server
p, _ := peer.FromContext(ctx) client, _ := peer.FromContext(ctx)
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate") log.Trace().
Caller().
Str("client_address", client.Addr.String()).
Msg("Client is trying to authenticate")
md, ok := metadata.FromIncomingContext(ctx) meta, ok := metadata.FromIncomingContext(ctx)
if !ok { if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed") log.Error().
return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed") Caller().
Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed")
return ctx, status.Errorf(
codes.InvalidArgument,
"Retrieving metadata is failed",
)
} }
authHeader, ok := md["authorization"] authHeader, ok := meta["authorization"]
if !ok { if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied") log.Error().
return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied") Caller().
Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied")
return ctx, status.Errorf(
codes.Unauthenticated,
"Authorization token is not supplied",
)
} }
token := authHeader[0] token := authHeader[0]
if !strings.HasPrefix(token, AUTH_PREFIX) { if !strings.HasPrefix(token, AuthPrefix) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(
codes.Unauthenticated,
`missing "Bearer " prefix in "Authorization" header`,
)
} }
// TODO(kradalby): Implement API key backend: // TODO(kradalby): Implement API key backend:
@ -307,7 +347,10 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// Currently all other than localhost traffic is unauthorized, this is intentional to allow // Currently all other than localhost traffic is unauthorized, this is intentional to allow
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities // us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
// and API key auth // and API key auth
return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet") return ctx, status.Error(
codes.Unauthenticated,
"Authentication is not implemented yet",
)
// if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token { // if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token") // log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
@ -317,25 +360,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// return handler(ctx, req) // return handler(ctx, req)
} }
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) { func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("client_address", c.ClientIP()). Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked") Msg("HTTP authentication invoked")
authHeader := c.GetHeader("authorization") authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) { if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", c.ClientIP()). Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
c.AbortWithStatus(http.StatusUnauthorized) ctx.AbortWithStatus(http.StatusUnauthorized)
return return
} }
c.AbortWithStatus(http.StatusUnauthorized) ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend // TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow // Currently all traffic is unauthorized, this is intentional to allow
@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
return nil return nil
} }
return os.Remove(h.cfg.UnixSocket) return os.Remove(h.cfg.UnixSocket)
} }
@ -401,14 +445,17 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port. // Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully. // The two following listeners will be served on the same port below gracefully.
m := cmux.New(networkListener) networkMutex := cmux.New(networkListener)
// Match gRPC requests here // Match gRPC requests here
grpcListener := m.MatchWithWriters( grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"), cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"), cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type",
"application/grpc+proto",
),
) )
// Otherwise match regular http requests. // Otherwise match regular http requests.
httpListener := m.Match(cmux.Any()) httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux() grpcGatewayMux := runtime.NewServeMux()
@ -431,30 +478,33 @@ func (h *Headscale) Serve() error {
return err return err
} }
r := gin.Default() router := gin.Default()
p := ginprometheus.NewPrometheus("gin") prometheus := ginprometheus.NewPrometheus("gin")
p.Use(r) prometheus.Use(router)
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }) router.GET(
r.GET("/key", h.KeyHandler) "/health",
r.GET("/register", h.RegisterWebAPI) func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
r.POST("/machine/:id/map", h.PollNetMapHandler) )
r.POST("/machine/:id", h.RegistrationHandler) router.GET("/key", h.KeyHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/register", h.RegisterWebAPI)
r.GET("/oidc/callback", h.OIDCCallback) router.POST("/machine/:id/map", h.PollNetMapHandler)
r.GET("/apple", h.AppleMobileConfig) router.POST("/machine/:id", h.RegistrationHandler)
r.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/swagger", SwaggerUI) router.GET("/oidc/callback", h.OIDCCallback)
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) router.GET("/apple", h.AppleMobileConfig)
router.GET("/apple/:platform", h.ApplePlatformConfig)
router.GET("/swagger", SwaggerUI)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := r.Group("/api") api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware) api.Use(h.httpAuthenticationMiddleware)
{ {
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP)) api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
} }
r.NoRoute(stdoutHandler) router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP) h.DERPMap = GetDERPMap(h.cfg.DERP)
@ -466,14 +516,13 @@ func (h *Headscale) Serve() error {
} }
// I HATE THIS // I HATE THIS
updateMillisecondsWait := int64(5000) go h.watchForKVUpdates(updateInterval)
go h.watchForKVUpdates(updateMillisecondsWait) go h.expireEphemeralNodes(updateInterval)
go h.expireEphemeralNodes(updateMillisecondsWait)
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: r, Handler: router,
ReadTimeout: 30 * time.Second, ReadTimeout: HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is // Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to // no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections // keep this at unlimited and be careful to clean up connections
@ -519,36 +568,40 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer) reflection.Register(grpcServer)
reflection.Register(grpcSocket) reflection.Register(grpcSocket)
g := new(errgroup.Group) errorGroup := new(errgroup.Group)
g.Go(func() error { return grpcSocket.Serve(socketListener) }) errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP // TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
g.Go(func() error { return grpcServer.Serve(grpcListener) }) errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil { if tlsConfig != nil {
g.Go(func() error { errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig) tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl) return httpServer.Serve(tlsl)
}) })
} else { } else {
g.Go(func() error { return httpServer.Serve(httpListener) }) errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
} }
g.Go(func() error { return m.Serve() }) errorGroup.Go(func() error { return networkMutex.Serve() })
log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr) log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
return g.Wait() return errorGroup.Wait()
} }
func (h *Headscale) getTLSSettings() (*tls.Config, error) { func (h *Headscale) getTLSSettings() (*tls.Config, error) {
var err error
if h.cfg.TLSLetsEncryptHostname != "" { if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") log.Warn().
Msg("Listening with TLS but ServerURL does not start with https://")
} }
m := autocert.Manager{ certManager := autocert.Manager{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname), HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Email: h.cfg.ACMEEmail, Email: h.cfg.ACMEEmail,
} }
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { switch h.cfg.TLSLetsEncryptChallengeType {
case "TLS-ALPN-01":
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale // The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443. // must be reachable on port 443.
return m.TLSConfig(), nil return certManager.TLSConfig(), nil
} else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" {
case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
go func() { go func() {
log.Fatal(). log.Fatal().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))). Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server") Msg("failed to set up a HTTP server")
}() }()
return m.TLSConfig(), nil return certManager.TLSConfig(), nil
} else {
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType") default:
return nil, errUnsupportedLetsEncryptChallengeType
} }
} else if h.cfg.TLSCertPath == "" { } else if h.cfg.TLSCertPath == "" {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") { if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
} }
return nil, nil return nil, err
} else { } else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
} }
var err error tlsConfig := &tls.Config{
tlsConfig := &tls.Config{} ClientAuth: tls.RequireAnyClientCert,
tlsConfig.ClientAuth = tls.RequireAnyClientCert NextProtos: []string{"http/1.1"},
tlsConfig.NextProtos = []string{"http/1.1"} Certificates: make([]tls.Certificate, 1),
tlsConfig.Certificates = make([]tls.Certificate, 1) MinVersion: tls.VersionTLS12,
}
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath) tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
return tlsConfig, err return tlsConfig, err
@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
} }
} }
func stdoutHandler(c *gin.Context) { func stdoutHandler(ctx *gin.Context) {
b, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
log.Trace(). log.Trace().
Interface("header", c.Request.Header). Interface("header", ctx.Request.Header).
Interface("proto", c.Request.Proto). Interface("proto", ctx.Request.Proto).
Interface("url", c.Request.URL). Interface("url", ctx.Request.URL).
Bytes("body", b). Bytes("body", body).
Msg("Request did not match") Msg("Request did not match")
} }

View file

@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
type Suite struct{} type Suite struct{}
var tmpDir string var (
var h Headscale tmpDir string
app Headscale
)
func (s *Suite) SetUpTest(c *check.C) { func (s *Suite) SetUpTest(c *check.C) {
s.ResetDB(c) s.ResetDB(c)
@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) {
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
} }
h = Headscale{ app = Headscale{
cfg: cfg, cfg: cfg,
dbType: "sqlite3", dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db", dbString: tmpDir + "/headscale_test.db",
} }
err = h.initDB() err = app.initDB()
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
db, err := h.openDB() db, err := app.openDB()
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
h.db = db app.db = db
} }

View file

@ -5,16 +5,15 @@ import (
"net/http" "net/http"
"text/template" "text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/rs/zerolog/log"
) )
// AppleMobileConfig shows a simple message in the browser to point to the CLI // AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register // Listens in /register.
func (h *Headscale) AppleMobileConfig(c *gin.Context) { func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
t := template.Must(template.New("apple").Parse(` appleTemplate := template.Must(template.New("apple").Parse(`
<html> <html>
<body> <body>
<h1>Apple configuration profiles</h1> <h1>Apple configuration profiles</h1>
@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
<p>Or</p> <p>Or</p>
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p> <p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
<code>defaults write io.tailscale.ipn.macos ControlURL {{.Url}}</code> <code>defaults write io.tailscale.ipn.macos ControlURL {{.URL}}</code>
<p>Restart Tailscale.app and log in.</p> <p>Restart Tailscale.app and log in.</p>
@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
</html>`)) </html>`))
config := map[string]interface{}{ config := map[string]interface{}{
"Url": h.cfg.ServerURL, "URL": h.cfg.ServerURL,
} }
var payload bytes.Buffer var payload bytes.Buffer
if err := t.Execute(&payload, config); err != nil { if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error(). log.Error().
Str("handler", "AppleMobileConfig"). Str("handler", "AppleMobileConfig").
Err(err). Err(err).
Msg("Could not render Apple index template") Msg("Could not render Apple index template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple index template"),
)
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
} }
func (h *Headscale) ApplePlatformConfig(c *gin.Context) { func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
platform := c.Param("platform") platform := ctx.Param("platform")
id, err := uuid.NewV4() id, err := uuid.NewV4()
if err != nil { if err != nil {
@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return return
} }
contentId, err := uuid.NewV4() contentID, err := uuid.NewV4()
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return return
} }
platformConfig := AppleMobilePlatformConfig{ platformConfig := AppleMobilePlatformConfig{
UUID: contentId, UUID: contentID,
Url: h.cfg.ServerURL, URL: h.cfg.ServerURL,
} }
var payload bytes.Buffer var payload bytes.Buffer
@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple macOS template") Msg("Could not render Apple macOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"),
)
return return
} }
case "ios": case "ios":
@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple iOS template") Msg("Could not render Apple iOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"),
)
return return
} }
default: default:
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported")) ctx.Data(
http.StatusOK,
"text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"),
)
return return
} }
config := AppleMobileConfig{ config := AppleMobileConfig{
UUID: id, UUID: id,
Url: h.cfg.ServerURL, URL: h.cfg.ServerURL,
Payload: payload.String(), Payload: payload.String(),
} }
@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple platform template") Msg("Could not render Apple platform template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple platform template"),
)
return return
} }
c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes()) ctx.Data(
http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8",
content.Bytes(),
)
} }
type AppleMobileConfig struct { type AppleMobileConfig struct {
UUID uuid.UUID UUID uuid.UUID
Url string URL string
Payload string Payload string
} }
type AppleMobilePlatformConfig struct { type AppleMobilePlatformConfig struct {
UUID uuid.UUID UUID uuid.UUID
Url string URL string
} }
var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?> var commonTemplate = template.Must(
template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0"> <plist version="1.0">
<dict> <dict>
@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
<key>PayloadDisplayName</key> <key>PayloadDisplayName</key>
<string>Headscale</string> <string>Headscale</string>
<key>PayloadDescription</key> <key>PayloadDescription</key>
<string>Configure Tailscale login server to: {{.Url}}</string> <string>Configure Tailscale login server to: {{.URL}}</string>
<key>PayloadIdentifier</key> <key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string> <string>com.github.juanfont.headscale</string>
<key>PayloadRemovalDisallowed</key> <key>PayloadRemovalDisallowed</key>
@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
{{.Payload}} {{.Payload}}
</array> </array>
</dict> </dict>
</plist>`)) </plist>`),
)
var iosTemplate = template.Must(template.New("iosTemplate").Parse(` var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<dict> <dict>
@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<true/> <true/>
<key>ControlURL</key> <key>ControlURL</key>
<string>{{.Url}}</string> <string>{{.URL}}</string>
</dict> </dict>
`)) `))
@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
<true/> <true/>
<key>ControlURL</key> <key>ControlURL</key>
<string>{{.Url}}</string> <string>{{.URL}}</string>
</dict> </dict>
`)) `))

View file

@ -7,31 +7,34 @@ import (
) )
func (s *Suite) TestRegisterMachine(c *check.C) { func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().UTC() now := time.Now().UTC()
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
IPAddress: "10.0.0.1", IPAddress: "10.0.0.1",
Expiry: &now, Expiry: &now,
RequestedExpiry: &now, RequestedExpiry: &now,
} }
h.db.Save(&m) app.db.Save(&machine)
_, err = h.GetMachine("test", "testmachine") _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name) machineAfterRegistering, err := app.RegisterMachine(
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
namespace.Name,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(m2.Registered, check.Equals, true) c.Assert(machineAfterRegistering.Registered, check.Equals, true)
_, err = m2.GetHostInfo() _, err = machineAfterRegistering.GetHostInfo()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }

View file

@ -27,7 +27,8 @@ func init() {
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("") log.Fatal().Err(err).Msg("")
} }
createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise") createNodeCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
debugCmd.AddCommand(createNodeCmd) debugCmd.AddCommand(createNodeCmd)
} }
@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{
name, err := cmd.Flags().GetString("name") name, err := cmd.Flags().GetString("name")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return return
} }
machineKey, err := cmd.Flags().GetString("key") machineKey, err := cmd.Flags().GetString("key")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return return
} }
routes, err := cmd.Flags().GetStringSlice("route") routes, err := cmd.Flags().GetStringSlice("route")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return return
} }
@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{
response, err := client.DebugCreateMachine(ctx, request) response, err := client.DebugCreateMachine(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
output,
)
return return
} }

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
survey "github.com/AlecAivazis/survey/v2" survey "github.com/AlecAivazis/survey/v2"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -19,6 +20,10 @@ func init() {
namespaceCmd.AddCommand(renameNamespaceCmd) namespaceCmd.AddCommand(renameNamespaceCmd)
} }
const (
errMissingParameter = headscale.Error("missing parameters")
)
var namespaceCmd = &cobra.Command{ var namespaceCmd = &cobra.Command{
Use: "namespaces", Use: "namespaces",
Short: "Manage the namespaces of Headscale", Short: "Manage the namespaces of Headscale",
@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{
Short: "Creates a new namespace", Short: "Creates a new namespace",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return errMissingParameter
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request") log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
response, err := client.CreateNamespace(ctx, request) response, err := client.CreateNamespace(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Cannot create namespace: %s",
status.Convert(err).Message(),
),
output,
)
return return
} }
@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{
Short: "Destroys a namespace", Short: "Destroys a namespace",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return errMissingParameter
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{
_, err := client.GetNamespace(ctx, request) _, err := client.GetNamespace(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output,
)
return return
} }
@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{
force, _ := cmd.Flags().GetBool("force") force, _ := cmd.Flags().GetBool("force")
if !force { if !force {
prompt := &survey.Confirm{ prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName), Message: fmt.Sprintf(
"Do you want to remove the namespace '%s' and any associated preauthkeys?",
namespaceName,
),
} }
err := survey.AskOne(prompt, &confirm) err := survey.AskOne(prompt, &confirm)
if err != nil { if err != nil {
@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{
response, err := client.DeleteNamespace(ctx, request) response, err := client.DeleteNamespace(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Cannot destroy namespace: %s",
status.Convert(err).Message(),
),
output,
)
return return
} }
SuccessOutput(response, "Namespace destroyed", output) SuccessOutput(response, "Namespace destroyed", output)
@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{
response, err := client.ListNamespaces(ctx, request) response, err := client.ListNamespaces(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
output,
)
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Namespaces, "", output) SuccessOutput(response.Namespaces, "", output)
return return
} }
d := pterm.TableData{{"ID", "Name", "Created"}} tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() { for _, namespace := range response.GetNamespaces() {
d = append( tableData = append(
d, tableData,
[]string{ []string{
namespace.GetId(), namespace.GetId(),
namespace.GetName(), namespace.GetName(),
@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{
}, },
) )
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return return
} }
}, },
@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{
Use: "rename OLD_NAME NEW_NAME", Use: "rename OLD_NAME NEW_NAME",
Short: "Renames a namespace", Short: "Renames a namespace",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 2 { expectedArguments := 2
return fmt.Errorf("Missing parameters") if len(args) < expectedArguments {
return errMissingParameter
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{
response, err := client.RenameNamespace(ctx, request) response, err := client.RenameNamespace(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename namespace: %s",
status.Convert(err).Message(),
),
output,
)
return return
} }

View file

@ -7,6 +7,7 @@ import (
"time" "time"
survey "github.com/AlecAivazis/survey/v2" survey "github.com/AlecAivazis/survey/v2"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{
machineKey, err := cmd.Flags().GetString("key") machineKey, err := cmd.Flags().GetString("key")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting machine key from flag: %s", err),
output,
)
return return
} }
@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{
response, err := client.RegisterMachine(ctx, request) response, err := client.RegisterMachine(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return return
} }
@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{
response, err := client.ListMachines(ctx, request) response, err := client.ListMachines(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Machines, "", output) SuccessOutput(response.Machines, "", output)
return return
} }
d, err := nodesToPtables(namespace, response.Machines) tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return return
} }
}, },
@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
id, err := cmd.Flags().GetInt("identifier") identifier, err := cmd.Flags().GetInt("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return return
} }
@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close() defer conn.Close()
getRequest := &v1.GetMachineRequest{ getRequest := &v1.GetMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
getResponse, err := client.GetMachine(ctx, getRequest) getResponse, err := client.GetMachine(ctx, getRequest)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Error getting node node: %s",
status.Convert(err).Message(),
),
output,
)
return return
} }
deleteRequest := &v1.DeleteMachineRequest{ deleteRequest := &v1.DeleteMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
confirm := false confirm := false
force, _ := cmd.Flags().GetBool("force") force, _ := cmd.Flags().GetBool("force")
if !force { if !force {
prompt := &survey.Confirm{ prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name), Message: fmt.Sprintf(
"Do you want to remove the node %s?",
getResponse.GetMachine().Name,
),
} }
err = survey.AskOne(prompt, &confirm) err = survey.AskOne(prompt, &confirm)
if err != nil { if err != nil {
@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{
response, err := client.DeleteMachine(ctx, deleteRequest) response, err := client.DeleteMachine(ctx, deleteRequest)
if output != "" { if output != "" {
SuccessOutput(response, "", output) SuccessOutput(response, "", output)
return return
} }
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Error deleting node: %s",
status.Convert(err).Message(),
),
output,
)
return return
} }
SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output) SuccessOutput(
map[string]string{"Result": "Node deleted"},
"Node deleted",
output,
)
} else { } else {
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output) SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
} }
@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{
func sharingWorker( func sharingWorker(
cmd *cobra.Command, cmd *cobra.Command,
args []string,
) (string, *v1.Machine, *v1.Namespace, error) { ) (string, *v1.Machine, *v1.Namespace, error) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
namespaceStr, err := cmd.Flags().GetString("namespace") namespaceStr, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return "", nil, nil, err return "", nil, nil, err
} }
@ -223,19 +280,25 @@ func sharingWorker(
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
id, err := cmd.Flags().GetInt("identifier") identifier, err := cmd.Flags().GetInt("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
return "", nil, nil, err return "", nil, nil, err
} }
machineRequest := &v1.GetMachineRequest{ machineRequest := &v1.GetMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
machineResponse, err := client.GetMachine(ctx, machineRequest) machineResponse, err := client.GetMachine(ctx, machineRequest)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err return "", nil, nil, err
} }
@ -245,7 +308,12 @@ func sharingWorker(
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest) namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err return "", nil, nil, err
} }
@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{
Use: "share", Use: "share",
Short: "Shares a node from the current namespace to the specified one", Short: "Shares a node from the current namespace to the specified one",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args) output, machine, namespace, err := sharingWorker(cmd)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return return
} }
@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{
response, err := client.ShareMachine(ctx, request) response, err := client.ShareMachine(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
output,
)
return return
} }
@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{
Use: "unshare", Use: "unshare",
Short: "Unshares a node from the specified namespace", Short: "Unshares a node from the specified namespace",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args) output, machine, namespace, err := sharingWorker(cmd)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return return
} }
@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{
response, err := client.UnshareMachine(ctx, request) response, err := client.UnshareMachine(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
output,
)
return return
} }
@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{
}, },
} }
func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) { func nodesToPtables(
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}} currentNamespace string,
machines []*v1.Machine,
) (pterm.TableData, error) {
tableData := pterm.TableData{
{
"ID",
"Name",
"NodeKey",
"Namespace",
"IP address",
"Ephemeral",
"Last seen",
"Online",
},
}
for _, machine := range machines { for _, machine := range machines {
var ephemeral bool var ephemeral bool
@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
nodeKey := tailcfg.NodeKey(nKey) nodeKey := tailcfg.NodeKey(nKey)
var online string var online string
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online if lastSeen.After(
time.Now().Add(-5 * time.Minute),
) { // TODO: Find a better way to reliably show if online
online = pterm.LightGreen("true") online = pterm.LightGreen("true")
} else { } else {
online = pterm.LightRed("false") online = pterm.LightRed("false")
@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
// Shared into this namespace // Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name) namespace = pterm.LightYellow(machine.Namespace.Name)
} }
d = append( tableData = append(
d, tableData,
[]string{ []string{
strconv.FormatUint(machine.Id, 10), strconv.FormatUint(machine.Id, headscale.Base10),
machine.Name, machine.Name,
nodeKey.ShortString(), nodeKey.ShortString(),
namespace, namespace,
@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
}, },
) )
} }
return d, nil
return tableData, nil
} }

View file

@ -12,6 +12,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
const (
DefaultPreAuthKeyExpiry = 24 * time.Hour
)
func init() { func init() {
rootCmd.AddCommand(preauthkeysCmd) rootCmd.AddCommand(preauthkeysCmd)
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace") preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
@ -22,10 +26,12 @@ func init() {
preauthkeysCmd.AddCommand(listPreAuthKeys) preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd) preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd) preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable") createPreAuthKeyCmd.PersistentFlags().
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes") Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().
Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags(). createPreAuthKeyCmd.Flags().
DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)") DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
} }
var preauthkeysCmd = &cobra.Command{ var preauthkeysCmd = &cobra.Command{
@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close() defer conn.Close()
request := &v1.ListPreAuthKeysRequest{ request := &v1.ListPreAuthKeysRequest{
Namespace: n, Namespace: namespace,
} }
response, err := client.ListPreAuthKeys(ctx, request) response, err := client.ListPreAuthKeys(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.PreAuthKeys, "", output) SuccessOutput(response.PreAuthKeys, "", output)
return return
} }
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}} tableData := pterm.TableData{
for _, k := range response.PreAuthKeys { {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
}
for _, key := range response.PreAuthKeys {
expiration := "-" expiration := "-"
if k.GetExpiration() != nil { if key.GetExpiration() != nil {
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05") expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
} }
var reusable string var reusable string
if k.GetEphemeral() { if key.GetEphemeral() {
reusable = "N/A" reusable = "N/A"
} else { } else {
reusable = fmt.Sprintf("%v", k.GetReusable()) reusable = fmt.Sprintf("%v", key.GetReusable())
} }
d = append(d, []string{ tableData = append(tableData, []string{
k.GetId(), key.GetId(),
k.GetKey(), key.GetKey(),
reusable, reusable,
strconv.FormatBool(k.GetEphemeral()), strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(k.GetUsed()), strconv.FormatBool(key.GetUsed()),
expiration, expiration,
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
}) })
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return return
} }
}, },
@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{
response, err := client.CreatePreAuthKey(ctx, request) response, err := client.CreatePreAuthKey(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output) ErrorOutput(
err,
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return return
} }
@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
Short: "Expire a preauthkey", Short: "Expire a preauthkey",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("missing parameters") return errMissingParameter
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{
response, err := client.ExpirePreAuthKey(ctx, request) response, err := client.ExpirePreAuthKey(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output) ErrorOutput(
err,
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return return
} }

View file

@ -10,7 +10,8 @@ import (
func init() { func init() {
rootCmd.PersistentFlags(). rootCmd.PersistentFlags().
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'") StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution") rootCmd.PersistentFlags().
Bool("force", false, "Disable prompts and forces the execution")
} }
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{

View file

@ -21,7 +21,8 @@ func init() {
} }
routesCmd.AddCommand(listRoutesCmd) routesCmd.AddCommand(listRoutesCmd)
enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable") enableRouteCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = enableRouteCmd.MarkFlagRequired("identifier") err = enableRouteCmd.MarkFlagRequired("identifier")
if err != nil { if err != nil {
@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier") machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return return
} }
@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{
defer conn.Close() defer conn.Close()
request := &v1.GetMachineRouteRequest{ request := &v1.GetMachineRouteRequest{
MachineId: machineId, MachineId: machineID,
} }
response, err := client.GetMachineRoute(ctx, request) response, err := client.GetMachineRoute(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Routes, "", output) SuccessOutput(response.Routes, "", output)
return return
} }
d := routesToPtables(response.Routes) tableData := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return return
} }
}, },
@ -93,15 +111,26 @@ omit the route you do not want to enable.
`, `,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier")
machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return return
} }
routes, err := cmd.Flags().GetStringSlice("route") routes, err := cmd.Flags().GetStringSlice("route")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return return
} }
@ -110,45 +139,61 @@ omit the route you do not want to enable.
defer conn.Close() defer conn.Close()
request := &v1.EnableMachineRoutesRequest{ request := &v1.EnableMachineRoutesRequest{
MachineId: machineId, MachineId: machineID,
Routes: routes, Routes: routes,
} }
response, err := client.EnableMachineRoutes(ctx, request) response, err := client.EnableMachineRoutes(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output) ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Routes, "", output) SuccessOutput(response.Routes, "", output)
return return
} }
d := routesToPtables(response.Routes) tableData := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return return
} }
}, },
} }
// routesToPtables converts the list of routes to a nice table // routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData { func routesToPtables(routes *v1.Routes) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}} tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() { for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route) enabled := isStringInSlice(routes.EnabledRoutes, route)
d = append(d, []string{route, strconv.FormatBool(enabled)}) tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
} }
return d
return tableData
} }
func isStringInSlice(strs []string, s string) bool { func isStringInSlice(strs []string, s string) bool {

View file

@ -52,9 +52,8 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.insecure", false) viper.SetDefault("cli.insecure", false)
viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.timeout", "5s")
err := viper.ReadInConfig() if err := viper.ReadInConfig(); err != nil {
if err != nil { return fmt.Errorf("fatal error reading config file: %w", err)
return fmt.Errorf("Fatal error reading config file: %s \n", err)
} }
// Collect any validation errors and return them all at once // Collect any validation errors and return them all at once
@ -82,6 +81,7 @@ func LoadConfig(path string) error {
errorText += "Fatal config error: server_url must start with https:// or http://\n" errorText += "Fatal config error: server_url must start with https:// or http://\n"
} }
if errorText != "" { if errorText != "" {
//nolint
return errors.New(strings.TrimSuffix(errorText, "\n")) return errors.New(strings.TrimSuffix(errorText, "\n"))
} else { } else {
return nil return nil
@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config.restricted_nameservers") { if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 { if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver) dnsConfig.Routes = make(map[string][]dnstype.Resolver)
restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers") restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS { for domain, restrictedNameservers := range restrictedDNS {
restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers)) restrictedResolvers := make(
[]dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers { for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr) nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil { if err != nil {
@ -208,6 +213,7 @@ func absPath(path string) string {
path = filepath.Join(dir, path) path = filepath.Join(dir, path)
} }
} }
return path return path
} }
@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config {
"10h", "10h",
) // use 10h here because it is the length of a standard business day plus a small amount of leeway ) // use 10h here because it is the length of a standard business day plus a small amount of leeway
if viper.GetDuration("max_machine_registration_duration") >= time.Second { if viper.GetDuration("max_machine_registration_duration") >= time.Second {
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") maxMachineRegistrationDuration = viper.GetDuration(
"max_machine_registration_duration",
)
} }
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not // defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config {
"8h", "8h",
) // use 8h here because it's the length of a standard business day ) // use 8h here because it's the length of a standard business day
if viper.GetDuration("default_machine_registration_duration") >= time.Second { if viper.GetDuration("default_machine_registration_duration") >= time.Second {
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") defaultMachineRegistrationDuration = viper.GetDuration(
"default_machine_registration_duration",
)
} }
dnsConfig, baseDomain := GetDNSConfig() dnsConfig, baseDomain := GetDNSConfig()
@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config {
DERP: derpConfig, DERP: derpConfig,
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
DBtype: viper.GetString("db_type"), DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")), DBpath: absPath(viper.GetString("db_path")),
@ -256,7 +268,9 @@ func getHeadscaleConfig() headscale.Config {
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"), TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")), TLSLetsEncryptCacheDir: absPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
TLSCertPath: absPath(viper.GetString("tls_cert_path")), TLSCertPath: absPath(viper.GetString("tls_cert_path")),
@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
// to avoid races // to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s") minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout { if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
// TODO: Find a better way to return this text
//nolint
err := fmt.Errorf( err := fmt.Errorf(
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n", "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"), viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout, minInactivityTimeout,
) )
return nil, err return nil, err
} }
@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap() cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg) app, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" { if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path")) aclPath := absPath(viper.GetString("acl_policy_path"))
err = h.LoadACLPolicy(aclPath) err = app.LoadACLPolicy(aclPath)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("path", aclPath). Str("path", aclPath).
@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
} }
} }
return h, nil return app, nil
} }
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
// If the address is not set, we assume that we are on the server hosting headscale. // If the address is not set, we assume that we are on the server hosting headscale.
if address == "" { if address == "" {
log.Debug(). log.Debug().
Str("socket", cfg.UnixSocket). Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.") Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
log.Fatal().Err(err) log.Fatal().Err(err)
} }
default: default:
//nolint
fmt.Println(override) fmt.Println(override)
return return
} }
//nolint
fmt.Println(string(j)) fmt.Println(string(j))
} }
@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool {
return true return true
} }
} }
return false return false
} }
@ -431,7 +451,10 @@ type tokenAuth struct {
} }
// Return value is mapped to request headers. // Return value is mapped to request headers.
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) { func (t tokenAuth) GetRequestMetadata(
ctx context.Context,
in ...string,
) (map[string]string, error) {
return map[string]string{ return map[string]string{
"authorization": "Bearer " + t.token, "authorization": "Bearer " + t.token,
}, nil }, nil

View file

@ -23,6 +23,8 @@ func main() {
colors = true colors = true
case termcolor.LevelBasic: case termcolor.LevelBasic:
colors = true colors = true
case termcolor.LevelNone:
colors = false
default: default:
// no color, return text as is. // no color, return text as is.
colors = false colors = false
@ -41,8 +43,7 @@ func main() {
NoColor: !colors, NoColor: !colors,
}) })
err := cli.LoadConfig("") if err := cli.LoadConfig(""); err != nil {
if err != nil {
log.Fatal().Err(err) log.Fatal().Err(err)
} }
@ -63,13 +64,15 @@ func main() {
} }
if !viper.GetBool("disable_check_updates") && !machineOutput { if !viper.GetBool("disable_check_updates") && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
cli.Version != "dev" {
githubTag := &latest.GithubTag{ githubTag := &latest.GithubTag{
Owner: "juanfont", Owner: "juanfont",
Repository: "headscale", Repository: "headscale",
} }
res, err := latest.Check(githubTag, cli.Version) res, err := latest.Check(githubTag, cli.Version)
if err == nil && res.Outdated { if err == nil && res.Outdated {
//nolint
fmt.Printf( fmt.Printf(
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n", "An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
res.Current, res.Current,

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
} }
// Symlink the example config file // Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml")) err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
} }
// Symlink the example config file // Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml")) err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
func writeConfig(c *check.C, tmpDir string, configYaml []byte) { func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file // Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml") configFile := filepath.Join(tmpDir, "config.yaml")
err := ioutil.WriteFile(configFile, configYaml, 0o644) err := ioutil.WriteFile(configFile, configYaml, 0o600)
if err != nil { if err != nil {
c.Fatalf("Couldn't write file %s", configFile) c.Fatalf("Couldn't write file %s", configFile)
} }
@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
c.Fatal(err) c.Fatal(err)
} }
// defer os.RemoveAll(tmpDir) // defer os.RemoveAll(tmpDir)
fmt.Println(tmpDir)
configYaml := []byte( configYaml := []byte(
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"", "---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
check.Matches, check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*", ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
) )
c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*") c.Assert(
fmt.Println(tmp) tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
// Check configuration validation errors (2) // Check configuration validation errors (2)
configYaml = []byte( configYaml = []byte(

35
db.go
View file

@ -9,7 +9,10 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
const dbVersion = "1" const (
dbVersion = "1"
errValueNotFound = Error("not found")
)
// KV is a key-value store in a psql table. For future use... // KV is a key-value store in a psql table. For future use...
type KV struct { type KV struct {
@ -24,7 +27,7 @@ func (h *Headscale) initDB() error {
} }
h.db = db h.db = db
if h.dbType == "postgres" { if h.dbType == Postgres {
db.Exec("create extension if not exists \"uuid-ossp\";") db.Exec("create extension if not exists \"uuid-ossp\";")
} }
err = db.AutoMigrate(&Machine{}) err = db.AutoMigrate(&Machine{})
@ -50,6 +53,7 @@ func (h *Headscale) initDB() error {
} }
err = h.setValue("db_version", dbVersion) err = h.setValue("db_version", dbVersion)
return err return err
} }
@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
} }
switch h.dbType { switch h.dbType {
case "sqlite3": case Sqlite:
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{ db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: log, Logger: log,
}) })
case "postgres": case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{ db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: log, Logger: log,
@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil return db, nil
} }
// getValue returns the value for the given key in KV // getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) { func (h *Headscale) getValue(key string) (string, error) {
var row KV var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.First(&row, "key = ?", key); errors.Is(
return "", errors.New("not found") result.Error,
gorm.ErrRecordNotFound,
) {
return "", errValueNotFound
} }
return row.Value, nil return row.Value, nil
} }
// setValue sets value for the given key in KV // setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error { func (h *Headscale) setValue(key string, value string) error {
kv := KV{ keyValue := KV{
Key: key, Key: key,
Value: value, Value: value,
} }
_, err := h.getValue(key) if _, err := h.getValue(key); err == nil {
if err == nil { h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
h.db.Model(&kv).Where("key = ?", key).Update("value", value)
return nil return nil
} }
h.db.Create(kv) h.db.Create(keyValue)
return nil return nil
} }

24
derp.go
View file

@ -1,6 +1,7 @@
package headscale package headscale
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil" "io/ioutil"
@ -10,9 +11,7 @@ import (
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
return nil, err return nil, err
} }
err = yaml.Unmarshal(b, &derpMap) err = yaml.Unmarshal(b, &derpMap)
return &derpMap, err return &derpMap, err
} }
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
client := http.Client{ ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
Timeout: 10 * time.Second, defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
if err != nil {
return nil, err
} }
resp, err := client.Get(addr.String())
client := http.Client{
Timeout: HTTPReadTimeout,
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
var derpMap tailcfg.DERPMap var derpMap tailcfg.DERPMap
err = json.Unmarshal(body, &derpMap) err = json.Unmarshal(body, &derpMap)
return &derpMap, err return &derpMap, err
} }
@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
// DERPMap, it will _only_ look at the Regions, an integer. // DERPMap, it will _only_ look at the Regions, an integer.
// If a region exists in two of the given DERPMaps, the region // If a region exists in two of the given DERPMaps, the region
// form the _last_ DERPMap will be preserved. // form the _last_ DERPMap will be preserved.
// An empty DERPMap list will result in a DERPMap with no regions // An empty DERPMap list will result in a DERPMap with no regions.
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
result := tailcfg.DERPMap{ result := tailcfg.DERPMap{
OmitDefaultRegions: false, OmitDefaultRegions: false,
@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("path", path). Str("path", path).
Err(err). Err(err).
Msg("Could not load DERP map from path") Msg("Could not load DERP map from path")
break break
} }
@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("url", addr.String()). Str("url", addr.String()).
Err(err). Err(err).
Msg("Could not load DERP map from path") Msg("Could not load DERP map from path")
break break
} }

34
dns.go
View file

@ -10,6 +10,10 @@ import (
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
) )
const (
ByteSize = 8
)
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS // This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for. // server (listening in 100.100.100.100 udp/53) should be used for.
@ -30,7 +34,9 @@ import (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries. // This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) { func generateMagicDNSRootDomains(
ipPrefix netaddr.IPPrefix,
) []dnsname.FQDN {
// TODO(juanfont): we are not handing out IPv6 addresses yet // TODO(juanfont): we are not handing out IPv6 addresses yet
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network) // and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.") ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
maskBits, _ := netRange.Mask.Size() maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask // lastOctet is the last IP byte covered by the mask
lastOctet := maskBits / 8 lastOctet := maskBits / ByteSize
// wildcardBits is the number of bits not under the mask in the lastOctet // wildcardBits is the number of bits not under the mask in the lastOctet
wildcardBits := 8 - maskBits%8 wildcardBits := ByteSize - maskBits%ByteSize
// min is the value in the lastOctet byte of the IP // min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
min := uint(netRange.IP[lastOctet]) min := uint(netRange.IP[lastOctet])
max := uint((min + 1<<uint(wildcardBits)) - 1) max := (min + 1<<uint(wildcardBits)) - 1
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.) // here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{} rdnsSlice := []string{}
@ -66,18 +72,27 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
} }
fqdns = append(fqdns, fqdn) fqdns = append(fqdns, fqdn)
} }
return fqdns, nil
return fqdns
} }
func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, m Machine, peers Machines) (*tailcfg.DNSConfig, error) { func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
machine Machine,
peers Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig var dnsConfig *tailcfg.DNSConfig
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN // Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
dnsConfig = dnsConfigOrig.Clone() dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain)) dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
)
namespaceSet := set.New(set.ThreadSafe) namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace) namespaceSet.Add(machine.Namespace)
for _, p := range peers { for _, p := range peers {
namespaceSet.Add(p.Namespace) namespaceSet.Add(p.Namespace)
} }
@ -88,5 +103,6 @@ func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string
} else { } else {
dnsConfig = dnsConfigOrig dnsConfig = dnsConfigOrig
} }
return dnsConfig, nil
return dnsConfig
} }

View file

@ -11,13 +11,13 @@ import (
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) { func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10") prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
domains, err := generateMagicDNSRootDomains(prefix, "foobar.headscale.net") domains := generateMagicDNSRootDomains(prefix)
c.Assert(err, check.IsNil)
found := false found := false
for _, domain := range domains { for _, domain := range domains {
if domain == "64.100.in-addr.arpa." { if domain == "64.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -27,6 +27,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "100.100.in-addr.arpa." { if domain == "100.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -36,6 +37,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "127.100.in-addr.arpa." { if domain == "127.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -44,13 +46,13 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) { func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16") prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net") domains := generateMagicDNSRootDomains(prefix)
c.Assert(err, check.IsNil)
found := false found := false
for _, domain := range domains { for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." { if domain == "0.16.172.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -60,6 +62,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." { if domain == "255.16.172.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -67,100 +70,120 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
} }
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
n1, err := h.CreateNamespace("shared1") namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2") namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3") namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) preAuthKeyInShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil) preAuthKeyInShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil) preAuthKeyInShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) PreAuthKey2InShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1") _, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m1 := &Machine{ machineInShared1 := &Machine{
ID: 1, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.1", IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
h.db.Save(m1) app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name) _, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m2 := &Machine{ machineInShared2 := &Machine{
ID: 2, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *n2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.2", IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
h.db.Save(m2) app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name) _, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m3 := &Machine{ machineInShared3 := &Machine{
ID: 3, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *n3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.3", IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
h.db.Save(m3) app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name) _, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ machine2InShared1 := &Machine{
ID: 4, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.4", IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID), AuthKeyID: uint(PreAuthKey2InShared1.ID),
} }
h.db.Save(m4) app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net" baseDomain := "foobar.headscale.net"
@ -170,122 +193,146 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Proxied: true, Proxied: true,
} }
m1peers, err := h.getPeers(m1) peersOfMachineInShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers) dnsConfig := getMapResponseDNSConfig(
c.Assert(err, check.IsNil) &dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachineInShared1,
)
c.Assert(dnsConfig, check.NotNil) c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 2) c.Assert(len(dnsConfig.Routes), check.Equals, 2)
routeN1 := fmt.Sprintf("%s.%s", n1.Name, baseDomain) domainRouteShared1 := fmt.Sprintf("%s.%s", namespaceShared1.Name, baseDomain)
_, ok := dnsConfig.Routes[routeN1] _, ok := dnsConfig.Routes[domainRouteShared1]
c.Assert(ok, check.Equals, true) c.Assert(ok, check.Equals, true)
routeN2 := fmt.Sprintf("%s.%s", n2.Name, baseDomain) domainRouteShared2 := fmt.Sprintf("%s.%s", namespaceShared2.Name, baseDomain)
_, ok = dnsConfig.Routes[routeN2] _, ok = dnsConfig.Routes[domainRouteShared2]
c.Assert(ok, check.Equals, true) c.Assert(ok, check.Equals, true)
routeN3 := fmt.Sprintf("%s.%s", n3.Name, baseDomain) domainRouteShared3 := fmt.Sprintf("%s.%s", namespaceShared3.Name, baseDomain)
_, ok = dnsConfig.Routes[routeN3] _, ok = dnsConfig.Routes[domainRouteShared3]
c.Assert(ok, check.Equals, false) c.Assert(ok, check.Equals, false)
} }
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
n1, err := h.CreateNamespace("shared1") namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2") namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3") namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) preAuthKeyInShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil) preAuthKeyInShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil) preAuthKeyInShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) preAuthKey2InShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1") _, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m1 := &Machine{ machineInShared1 := &Machine{
ID: 1, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.1", IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
h.db.Save(m1) app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name) _, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m2 := &Machine{ machineInShared2 := &Machine{
ID: 2, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *n2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.2", IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
h.db.Save(m2) app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name) _, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m3 := &Machine{ machineInShared3 := &Machine{
ID: 3, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *n3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.3", IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
h.db.Save(m3) app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name) _, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ machine2InShared1 := &Machine{
ID: 4, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.4", IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID), AuthKeyID: uint(preAuthKey2InShared1.ID),
} }
h.db.Save(m4) app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net" baseDomain := "foobar.headscale.net"
@ -295,11 +342,15 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Proxied: false, Proxied: false,
} }
m1peers, err := h.getPeers(m1) peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers) dnsConfig := getMapResponseDNSConfig(
c.Assert(err, check.IsNil) &dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachine1Shared1,
)
c.Assert(dnsConfig, check.NotNil) c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 0) c.Assert(len(dnsConfig.Routes), check.Equals, 0)
c.Assert(len(dnsConfig.Domains), check.Equals, 1) c.Assert(len(dnsConfig.Domains), check.Equals, 1)

View file

@ -1,6 +1,7 @@
# Running headscale # Running headscale
## Server configuration ## Server configuration
1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container 1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container
```shell ```shell
@ -22,6 +23,7 @@
3. Get yourself a DB 3. Get yourself a DB
a) Get a Postgres DB running in docker a) Get a Postgres DB running in docker
```shell ```shell
docker run --name headscale \ docker run --name headscale \
-e POSTGRES_DB=headscale -e POSTGRES_DB=headscale
@ -30,7 +32,9 @@
-p 5432:5432 \ -p 5432:5432 \
-d postgres -d postgres
``` ```
or b) Prepare a SQLite DB file or b) Prepare a SQLite DB file
```shell ```shell
touch config/db.sqlite touch config/db.sqlite
``` ```
@ -81,6 +85,7 @@
-p 127.0.0.1:8080:8080 \ -p 127.0.0.1:8080:8080 \
headscale/headscale:x.x.x headscale serve headscale/headscale:x.x.x headscale serve
``` ```
## Nodes configuration ## Nodes configuration
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
@ -90,7 +95,9 @@ If you used tailscale.com before in your nodes, make sure you clear the tailscal
rm -fr /var/lib/tailscale rm -fr /var/lib/tailscale
systemctl start tailscaled systemctl start tailscaled
``` ```
### Adding node based on MACHINEKEY ### Adding node based on MACHINEKEY
1. Add your first machine 1. Add your first machine
```shell ```shell

View file

@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
ctx context.Context, ctx context.Context,
request *v1.RegisterMachineRequest, request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) { ) (*v1.RegisterMachineResponse, error) {
log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine") log.Trace().
Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()).
Msg("Registering machine")
machine, err := api.h.RegisterMachine( machine, err := api.h.RegisterMachine(
request.GetKey(), request.GetKey(),
request.GetNamespace(), request.GetNamespace(),
@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
return nil, err return nil, err
} }
sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace()) sharedMachines, err := api.h.ListSharedMachinesInNamespace(
request.GetNamespace(),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err return nil, err
} }
routes, err := stringToIpPrefix(request.GetRoutes()) routes, err := stringToIPPrefix(request.GetRoutes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("") log.Trace().
Caller().
Interface("route-prefix", routes).
Interface("route-str", request.GetRoutes()).
Msg("")
hostinfo := tailcfg.Hostinfo{ hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes, RoutableIPs: routes,

View file

@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK") return fmt.Errorf("status code not OK")
} }
return nil return nil
}); err != nil { }); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during // TODO(kradalby): If we cannot access headscale, or any other fatal error during
@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
} }
} }
func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) { func (s *IntegrationCLITestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats s.stats = stats
} }
@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() {
namespaces := make([]*v1.Namespace, len(names)) namespaces := make([]*v1.Namespace, len(names))
for index, namespaceName := range names { for index, namespaceName := range names {
namespace, err := s.createNamespace(namespaceName) namespace, err := s.createNamespace(namespaceName)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) assert.True(
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) s.T(),
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) )
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) assert.True(
s.T(),
listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
// Expire three keys // Expire three keys
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys) err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now())) assert.True(
assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now())) s.T(),
assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now())) listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now())) )
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now())) assert.True(
s.T(),
listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
)
} }
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() { func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
var listOnlySharedMachineNamespace []v1.Machine var listOnlySharedMachineNamespace []v1.Machine
err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace) err = json.Unmarshal(
[]byte(listOnlySharedMachineNamespaceResult),
&listOnlySharedMachineNamespace,
)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2) assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterDelete []v1.Machine var listOnlyMachineNamespaceAfterDelete []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete) err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterDeleteResult),
&listOnlyMachineNamespaceAfterDelete,
)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4) assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterShare []v1.Machine var listOnlyMachineNamespaceAfterShare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare) err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterShareResult),
&listOnlyMachineNamespaceAfterShare,
)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5) assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterUnshare []v1.Machine var listOnlyMachineNamespaceAfterUnshare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare) err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterUnshareResult),
&listOnlyMachineNamespaceAfterUnshare,
)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4) assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node") assert.Contains(
s.T(),
string(failEnableNonAdvertisedRoute),
"route (route-machine) is not available on node",
)
} }

View file

@ -12,12 +12,18 @@ import (
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
) )
func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) { const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
func ExecuteCommand(
resource *dockertest.Resource,
cmd []string,
env []string,
) (string, error) {
var stdout bytes.Buffer var stdout bytes.Buffer
var stderr bytes.Buffer var stderr bytes.Buffer
// TODO(kradalby): Make configurable // TODO(kradalby): Make configurable
timeout := 10 * time.Second timeout := DOCKER_EXECUTE_TIMEOUT
type result struct { type result struct {
exitCode int exitCode int
@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (
fmt.Println("Command: ", cmd) fmt.Println("Command: ", cmd)
fmt.Println("stdout: ", stdout.String()) fmt.Println("stdout: ", stdout.String())
fmt.Println("stderr: ", stderr.String()) fmt.Println("stderr: ", stderr.String())
return "", fmt.Errorf("command failed with: %s", stderr.String()) return "", fmt.Errorf("command failed with: %s", stderr.String())
} }
return stdout.String(), nil return stdout.String(), nil
case <-time.After(timeout): case <-time.After(timeout):
return "", fmt.Errorf("command timed out after %s", timeout) return "", fmt.Errorf("command timed out after %s", timeout)
} }
} }

View file

@ -23,10 +23,9 @@ import (
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"inet.af/netaddr"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"inet.af/netaddr"
) )
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"} var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) {
} }
} }
func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error { func (s *IntegrationTestSuite) saveLog(
resource *dockertest.Resource,
basePath string,
) error {
err := os.MkdirAll(basePath, os.ModePerm) err := os.MkdirAll(basePath, os.ModePerm)
if err != nil { if err != nil {
return err return err
@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644) err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil { if err != nil {
return err return err
} }
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644) err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil { if err != nil {
return err return err
} }
@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer(
}, },
}, },
} }
hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier) hostname := fmt.Sprintf(
"%s-tailscale-%s-%s",
namespace,
strings.Replace(version, ".", "-", -1),
identifier,
)
tailscaleOptions := &dockertest.RunOptions{ tailscaleOptions := &dockertest.RunOptions{
Name: hostname, Name: hostname,
Networks: []*dockertest.Network{&s.network}, Networks: []*dockertest.Network{&s.network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"}, Cmd: []string{
"tailscaled",
"--tun=userspace-networking",
"--socks5-server=localhost:1055",
},
} }
pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy) pts, err := s.pool.BuildAndRunWithBuildOptions(
tailscaleBuildOptions,
tailscaleOptions,
DockerRestartPolicy,
)
if err != nil { if err != nil {
log.Fatalf("Could not start resource: %s", err) log.Fatalf("Could not start resource: %s", err)
} }
fmt.Printf("Created %s container\n", hostname) fmt.Printf("Created %s container\n", hostname)
return hostname, pts return hostname, pts
} }
func (s *IntegrationTestSuite) SetupSuite() { func (s *IntegrationTestSuite) SetupSuite() {
var err error var err error
h = Headscale{ app = Headscale{
dbType: "sqlite3", dbType: "sqlite3",
dbString: "integration_test_db.sqlite3", dbString: "integration_test_db.sqlite3",
} }
@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
for i := 0; i < scales.count; i++ { for i := 0; i < scales.count; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)] version := tailscaleVersions[i%len(tailscaleVersions)]
hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version) hostname, container := s.tailscaleContainer(
namespace,
fmt.Sprint(i),
version,
)
scales.tailscales[hostname] = *container scales.tailscales[hostname] = *container
} }
} }
@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
if err := s.pool.Retry(func() error { if err := s.pool.Retry(func() error {
url := fmt.Sprintf("http://%s/health", hostEndpoint) url := fmt.Sprintf("http://%s/health", hostEndpoint)
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return err return err
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK") return fmt.Errorf("status code not OK")
} }
return nil return nil
}); err != nil { }); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during // TODO(kradalby): If we cannot access headscale, or any other fatal error during
@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
headscaleEndpoint := "http://headscale:8080" headscaleEndpoint := "http://headscale:8080"
fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint) fmt.Printf(
"Joining tailscale containers to headscale at %s\n",
headscaleEndpoint,
)
for hostname, tailscale := range scales.tailscales { for hostname, tailscale := range scales.tailscales {
command := []string{ command := []string{
"tailscale", "tailscale",
@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
func (s *IntegrationTestSuite) TearDownSuite() { func (s *IntegrationTestSuite) TearDownSuite() {
} }
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) { func (s *IntegrationTestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats s.stats = stats
} }
@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
ip.String(), ip.String(),
} }
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand( result, err := ExecuteCommand(
&tailscale, &tailscale,
command, command,
@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
result, err := ExecuteCommand( result, err := ExecuteCommand(
&s.headscale, &s.headscale,
[]string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"}, []string{
"headscale",
"nodes",
"list",
"--output",
"json",
"--namespace",
"shared",
},
[]string{}, []string{},
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for _, machine := range machineList { for _, machine := range machineList {
result, err := ExecuteCommand( result, err := ExecuteCommand(
&s.headscale, &s.headscale,
[]string{ []string{
@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
ip.String(), ip.String(),
} }
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip) fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
mainIps[hostname],
peername,
ip,
)
result, err := ExecuteCommand( result, err := ExecuteCommand(
&tailscale, &tailscale,
command, command,
@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
for peername, ip := range ips { for peername, ip := range ips {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname { if peername != hostname {
// Under normal circumstances, we should be able to send a file // Under normal circumstances, we should be able to send a file
// using `tailscale file cp` - but not in userspace networking mode // using `tailscale file cp` - but not in userspace networking mode
// So curl! // So curl!
@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"PUT", "PUT",
"--upload-file", "--upload-file",
fmt.Sprintf("/tmp/file_from_%s", hostname), fmt.Sprintf("/tmp/file_from_%s", hostname),
fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname), fmt.Sprintf(
"%s/v0/put/file_from_%s",
peerAPI,
hostname,
),
} }
fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) fmt.Printf(
"Sending file from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
_, err = ExecuteCommand( _, err = ExecuteCommand(
&tailscale, &tailscale,
command, command,
@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"ls", "ls",
fmt.Sprintf("/tmp/file_from_%s", peername), fmt.Sprintf("/tmp/file_from_%s", peername),
} }
fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip) fmt.Printf(
"Checking file in %s (%s) from %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand( result, err := ExecuteCommand(
&tailscale, &tailscale,
command, command,
@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
) )
assert.Nil(t, err) assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result) fmt.Printf("Result for %s: %s\n", peername, result)
assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername)) assert.Equal(
t,
result,
fmt.Sprintf("/tmp/file_from_%s\n", peername),
)
} }
}) })
} }
@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
ips[hostname] = ip ips[hostname] = ip
} }
return ips, nil return ips, nil
} }
func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) { func getAPIURLs(
tailscales map[string]dockertest.Resource,
) (map[netaddr.IP]string, error) {
fts := make(map[netaddr.IP]string) fts := make(map[netaddr.IP]string)
for _, tailscale := range tailscales { for _, tailscale := range tailscales {
command := []string{ command := []string{
@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin
} }
} }
} }
return fts, nil return fts, nil
} }

View file

@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed.
You'll somehow need to get `headscale:latest` into your cluster image registry. You'll somehow need to get `headscale:latest` into your cluster image registry.
An easy way to do this with k3s: An easy way to do this with k3s:
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`) - Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
- `docker build -t headscale:latest ..` from here - `docker build -t headscale:latest ..` from here

View file

@ -10,10 +10,9 @@ import (
"time" "time"
"github.com/fatih/set" "github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"gorm.io/datatypes" "gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
@ -21,7 +20,13 @@ import (
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
// Machine is a Headscale client const (
errMachineNotFound = Error("machine not found")
errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine")
)
// Machine is a Headscale client.
type Machine struct { type Machine struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"` MachineKey string `gorm:"type:varchar(64);unique_index"`
@ -56,53 +61,58 @@ type (
MachinesP []*Machine MachinesP []*Machine
) )
// For the time being this method is rather naive // For the time being this method is rather naive.
func (m Machine) isAlreadyRegistered() bool { func (machine Machine) isAlreadyRegistered() bool {
return m.Registered return machine.Registered
} }
// isExpired returns whether the machine registration has expired // isExpired returns whether the machine registration has expired.
func (m Machine) isExpired() bool { func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry) return time.Now().UTC().After(*machine.Expiry)
} }
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, // If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause // or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the // a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time. // expiry time.
func (h *Headscale) updateMachineExpiry(m *Machine) { func (h *Headscale) updateMachineExpiry(machine *Machine) {
if m.isExpired() { if machine.isExpired() {
now := time.Now().UTC() now := time.Now().UTC()
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry maxExpiry := now.Add(
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry h.cfg.MaxMachineRegistrationDuration,
) // calculate the maximum expiry
defaultExpiry := now.Add(
h.cfg.DefaultMachineRegistrationDuration,
) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) { if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug(). log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
m.Expiry = &maxExpiry machine.Expiry = &maxExpiry
} else if m.RequestedExpiry.IsZero() { } else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
m.Expiry = &defaultExpiry machine.Expiry = &defaultExpiry
} else { } else {
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
m.Expiry = m.RequestedExpiry machine.Expiry = machine.RequestedExpiry
} }
h.db.Save(&m) h.db.Save(&machine)
} }
} }
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding direct peers") Msg("Finding direct peers")
machines := Machines{} machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err return Machines{}, err
} }
@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String()) Msgf("Found direct machines: %s", machines.String())
return machines, nil return machines, nil
} }
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
func (h *Headscale) getShared(m *Machine) (Machines, error) { func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding shared peers") Msg("Finding shared peers")
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil { machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err return Machines{}, err
} }
@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String()) Msgf("Found shared peers: %s", peers.String())
return peers, nil return peers, nil
} }
// getSharedTo fetches the machines of the namespaces this machine is shared in // getSharedTo fetches the machines of the namespaces this machine is shared in.
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with") Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil { machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err return Machines{}, err
} }
peers := make(Machines, 0) peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines { for _, sharedMachine := range sharedMachines {
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name) namespaceMachines, err := h.ListMachinesInNamespace(
sharedMachine.Namespace.Name,
)
if err != nil { if err != nil {
return Machines{}, err return Machines{}, err
} }
@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String()) Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil return peers, nil
} }
func (h *Headscale) getPeers(m *Machine) (Machines, error) { func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m) direct, err := h.getDirectPeers(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
shared, err := h.getShared(m) shared, err := h.getShared(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
sharedTo, err := h.getSharedTo(m) sharedTo, err := h.getSharedTo(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String()) Msgf("Found total peers: %s", peers.String())
return peers, nil return peers, nil
@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
return machines, nil return machines, nil
} }
// GetMachine finds a Machine by name and namespace and returns the Machine struct // GetMachine finds a Machine by name and namespace and returns the Machine struct.
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
machines, err := h.ListMachinesInNamespace(namespace) machines, err := h.ListMachinesInNamespace(namespace)
if err != nil { if err != nil {
@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil return &m, nil
} }
} }
return nil, fmt.Errorf("machine not found")
return nil, errMachineNotFound
} }
// GetMachineByID finds a Machine by ID and returns the Machine struct // GetMachineByID finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{} m := Machine{}
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return &m, nil return &m, nil
} }
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{} m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return &m, nil return &m, nil
} }
// UpdateMachine takes a Machine struct pointer (typically already loaded from database // UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database. // and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error { func (h *Headscale) UpdateMachine(machine *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil { if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error return result.Error
} }
return nil return nil
} }
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared { if err != nil && errors.Is(err, errMachineNotShared) {
return err return err
} }
m.Registered = false machine.Registered = false
namespaceID := m.NamespaceID namespaceID := machine.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil { if err := h.db.Delete(&machine).Error; err != nil {
return err return err
} }
return h.RequestMapUpdates(namespaceID) return h.RequestMapUpdates(namespaceID)
} }
// HardDeleteMachine hard deletes a Machine from the database // HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared { if err != nil && errors.Is(err, errMachineNotShared) {
return err return err
} }
namespaceID := m.NamespaceID namespaceID := machine.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil { if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err return err
} }
return h.RequestMapUpdates(namespaceID) return h.RequestMapUpdates(namespaceID)
} }
// GetHostInfo returns a Hostinfo struct for the machine // GetHostInfo returns a Hostinfo struct for the machine.
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return nil, err return nil, err
} }
} }
return &hostinfo, nil return &hostinfo, nil
} }
func (h *Headscale) isOutdated(m *Machine) bool { func (h *Headscale) isOutdated(machine *Machine) bool {
err := h.UpdateMachine(m) if err := h.UpdateMachine(machine); err != nil {
if err != nil {
// It does not seem meaningful to propagate this error as the end result // It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated. // will have to be that the machine has to be considered outdated.
return true return true
} }
sharedMachines, _ := h.getShared(m) sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe) namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace.Name) namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have // Check if any of our shared namespaces has updates that we have
// not propagated. // not propagated.
@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool {
namespaces := make([]string, namespaceSet.Size()) namespaces := make([]string, namespaceSet.Size())
for index, namespace := range namespaceSet.List() { for index, namespace := range namespaceSet.List() {
namespaces[index] = namespace.(string) if name, ok := namespace.(string); ok {
namespaces[index] = name
}
} }
lastChange := h.getLastStateChange(namespaces...) lastChange := h.getLastStateChange(namespaces...)
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange). Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name) Msgf("Checking if %s is missing updates", machine.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
return machine.LastSuccessfulUpdate.Before(lastChange)
} }
func (m Machine) String() string { func (machine Machine) String() string {
return m.Name return machine.Name
} }
func (ms Machines) String() string { func (machines Machines) String() string {
temp := make([]string, len(ms)) temp := make([]string, len(machines))
for index, machine := range ms { for index, machine := range machines {
temp[index] = machine.Name temp[index] = machine.Name
} }
@ -361,24 +387,24 @@ func (ms Machines) String() string {
} }
// TODO(kradalby): Remove when we have generics... // TODO(kradalby): Remove when we have generics...
func (ms MachinesP) String() string { func (machines MachinesP) String() string {
temp := make([]string, len(ms)) temp := make([]string, len(machines))
for index, machine := range ms { for index, machine := range machines {
temp[index] = machine.Name temp[index] = machine.Name
} }
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (ms Machines) toNodes( func (machines Machines) toNodes(
baseDomain string, baseDomain string,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
includeRoutes bool, includeRoutes bool,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms)) nodes := make([]*tailcfg.Node, len(machines))
for index, machine := range ms { for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes) node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
@ -391,20 +417,25 @@ func (ms Machines) toNodes(
} }
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS // as per the expected behaviour in the official SaaS.
func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) { func (machine Machine) toNode(
nKey, err := wgkey.ParseHex(m.NodeKey) baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mKey, err := wgkey.ParseHex(m.MachineKey)
machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var discoKey tailcfg.DiscoKey var discoKey tailcfg.DiscoKey
if m.DiscoKey != "" { if machine.DiscoKey != "" {
dKey, err := wgkey.ParseHex(m.DiscoKey) dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
} }
addrs := []netaddr.IPPrefix{} addrs := []netaddr.IPPrefix{}
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Str("ip", m.IPAddress). Str("ip", machine.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
return nil, err return nil, err
} }
addrs = append(addrs, ip) // missing the ipv6 ? addrs = append(addrs, ip) // missing the ipv6 ?
allowedIPs := []netaddr.IPPrefix{} allowedIPs := []netaddr.IPPrefix{}
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients allowedIPs = append(
allowedIPs,
ip,
) // we append the node own IP, as it is required by the clients
if includeRoutes { if includeRoutes {
routesStr := []string{} routesStr := []string{}
if len(m.EnabledRoutes) != 0 { if len(machine.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON() allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
} }
endpoints := []string{} endpoints := []string{}
if len(m.Endpoints) != 0 { if len(machine.Endpoints) != 0 {
be, err := m.Endpoints.MarshalJSON() be, err := machine.Endpoints.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
} }
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
} }
var keyExpiry time.Time var keyExpiry time.Time
if m.Expiry != nil { if machine.Expiry != nil {
keyExpiry = *m.Expiry keyExpiry = *machine.Expiry
} else { } else {
keyExpiry = time.Time{} keyExpiry = time.Time{}
} }
var hostname string var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain) hostname = fmt.Sprintf(
"%s.%s.%s",
machine.Name,
machine.Namespace.Name,
baseDomain,
)
} else { } else {
hostname = m.Name hostname = machine.Name
} }
n := tailcfg.Node{ node := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID( StableID: tailcfg.StableNodeID(
strconv.FormatUint(m.ID, 10), strconv.FormatUint(machine.ID, Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent ), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname, Name: hostname,
User: tailcfg.UserID(m.NamespaceID), User: tailcfg.UserID(machine.NamespaceID),
Key: tailcfg.NodeKey(nKey), Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry, KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey), Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey, DiscoKey: discoKey,
Addresses: addrs, Addresses: addrs,
AllowedIPs: allowedIPs, AllowedIPs: allowedIPs,
@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
DERP: derp, DERP: derp,
Hostinfo: hostinfo, Hostinfo: hostinfo,
Created: m.CreatedAt, Created: machine.CreatedAt,
LastSeen: m.LastSeen, LastSeen: machine.LastSeen,
KeepAlive: true, KeepAlive: true,
MachineAuthorized: m.Registered, MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
return &n, nil
return &node, nil
} }
func (m *Machine) toProto() *v1.Machine { func (machine *Machine) toProto() *v1.Machine {
machine := &v1.Machine{ machineProto := &v1.Machine{
Id: m.ID, Id: machine.ID,
MachineKey: m.MachineKey, MachineKey: machine.MachineKey,
NodeKey: m.NodeKey, NodeKey: machine.NodeKey,
DiscoKey: m.DiscoKey, DiscoKey: machine.DiscoKey,
IpAddress: m.IPAddress, IpAddress: machine.IPAddress,
Name: m.Name, Name: machine.Name,
Namespace: m.Namespace.toProto(), Namespace: machine.Namespace.toProto(),
Registered: m.Registered, Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter // TODO(kradalby): Implement register method enum converter
// RegisterMethod: , // RegisterMethod: ,
CreatedAt: timestamppb.New(m.CreatedAt), CreatedAt: timestamppb.New(machine.CreatedAt),
} }
if m.AuthKey != nil { if machine.AuthKey != nil {
machine.PreAuthKey = m.AuthKey.toProto() machineProto.PreAuthKey = machine.AuthKey.toProto()
} }
if m.LastSeen != nil { if machine.LastSeen != nil {
machine.LastSeen = timestamppb.New(*m.LastSeen) machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
} }
if m.LastSuccessfulUpdate != nil { if machine.LastSuccessfulUpdate != nil {
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) machineProto.LastSuccessfulUpdate = timestamppb.New(
*machine.LastSuccessfulUpdate,
)
} }
if m.Expiry != nil { if machine.Expiry != nil {
machine.Expiry = timestamppb.New(*m.Expiry) machineProto.Expiry = timestamppb.New(*machine.Expiry)
} }
return machine return machineProto
} }
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { func (h *Headscale) RegisterMachine(
ns, err := h.GetNamespace(namespace) key string,
namespaceName string,
) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mKey, err := wgkey.ParseHex(key) machineKey, err := wgkey.ParseHex(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := Machine{} machine := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
return nil, errors.New("Machine not found") result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errMachineNotFound
} }
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
if m.isAlreadyRegistered() { if machine.isAlreadyRegistered() {
err := errors.New("Machine already registered") err := errMachineAlreadyRegistered
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
return nil, err return nil, err
@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Could not find IP for the new machine") Msg("Could not find IP for the new machine")
return nil, err return nil, err
} }
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Found IP for host") Msg("Found IP for host")
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = ns.ID machine.NamespaceID = namespace.ID
m.Registered = true machine.Registered = true
m.RegisterMethod = "cli" machine.RegisterMethod = "cli"
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Machine registered with the database") Msg("Machine registered with the database")
return &m, nil return &machine, nil
} }
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := m.GetHostInfo() hostInfo, err := machine.GetHostInfo()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return hostInfo.RoutableIPs, nil return hostInfo.RoutableIPs, nil
} }
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := m.EnabledRoutes.MarshalJSON() data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil return routes, nil
} }
func (m *Machine) IsRoutesEnabled(routeStr string) bool { func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil { if err != nil {
return false return false
} }
enabledRoutes, err := m.GetEnabledRoutes() enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil { if err != nil {
return false return false
} }
@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
return true return true
} }
} }
return false return false
} }
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the // EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes. // previous list of routes.
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs { for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route newRoutes[index] = route
} }
availableRoutes, err := m.GetAdvertisedRoutes() availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil { if err != nil {
return err return err
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
if !containsIpPrefix(availableRoutes, newRoute) { if !containsIPPrefix(availableRoutes, newRoute) {
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute) return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Name,
newRoute, errMachineRouteIsNotAvailable,
)
} }
} }
@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err return err
} }
m.EnabledRoutes = datatypes.JSON(routes) machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m) h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID) err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil { if err != nil {
return err return err
} }
@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil return nil
} }
func (m *Machine) RoutesToProto() (*v1.Routes, error) { func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := m.GetAdvertisedRoutes() availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
enabledRoutes, err := m.GetEnabledRoutes() enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,152 +8,159 @@ import (
) )
func (s *Suite) TestGetMachine(c *check.C) { func (s *Suite) TestGetMachine(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine") _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m := &Machine{ machine := &Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(m) app.db.Save(machine)
m1, err := h.GetMachine("test", "testmachine") machineFromDB, err := app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo() _, err = machineFromDB.GetHostInfo()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestGetMachineByID(c *check.C) { func (s *Suite) TestGetMachineByID(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachineByID(0) _, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
m1, err := h.GetMachineByID(0) machineByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo() _, err = machineByID.GetHostInfo()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestDeleteMachine(c *check.C) { func (s *Suite) TestDeleteMachine(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(1), AuthKeyID: uint(1),
} }
h.db.Save(&m) app.db.Save(&machine)
err = h.DeleteMachine(&m)
err = app.DeleteMachine(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
v, err := h.getValue("namespaces_pending_updates")
namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
names := []string{} names := []string{}
err = json.Unmarshal([]byte(v), &names) err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(names, check.DeepEquals, []string{n.Name}) c.Assert(names, check.DeepEquals, []string{namespace.Name})
h.checkForNamespacesPendingUpdates()
v, _ = h.getValue("namespaces_pending_updates") app.checkForNamespacesPendingUpdates()
c.Assert(v, check.Equals, "")
_, err = h.GetMachine(n.Name, "testmachine") namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates")
c.Assert(namespacesPendingUpdates, check.Equals, "")
_, err = app.GetMachine(namespace.Name, "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestHardDeleteMachine(c *check.C) { func (s *Suite) TestHardDeleteMachine(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine3", Name: "testmachine3",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(1), AuthKeyID: uint(1),
} }
h.db.Save(&m) app.db.Save(&machine)
err = h.HardDeleteMachine(&m)
err = app.HardDeleteMachine(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine(n.Name, "testmachine3")
_, err = app.GetMachine(namespace.Name, "testmachine3")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestGetDirectPeers(c *check.C) { func (s *Suite) TestGetDirectPeers(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachineByID(0) _, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
for i := 0; i <= 10; i++ { for index := 0; index <= 10; index++ {
m := Machine{ machine := Machine{
ID: uint64(i), ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(i), MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(i), NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(i), DiscoKey: "faa" + strconv.Itoa(index),
Name: "testmachine" + strconv.Itoa(i), Name: "testmachine" + strconv.Itoa(index),
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
} }
m1, err := h.GetMachineByID(0) machine0ByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo() _, err = machine0ByID.GetHostInfo()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peers, err := h.getDirectPeers(m1) peersOfMachine0, err := app.getDirectPeers(machine0ByID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(peers), check.Equals, 9) c.Assert(len(peersOfMachine0), check.Equals, 9)
c.Assert(peers[0].Name, check.Equals, "testmachine2") c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2")
c.Assert(peers[5].Name, check.Equals, "testmachine7") c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7")
c.Assert(peers[8].Name, check.Equals, "testmachine10") c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10")
} }

View file

@ -15,9 +15,9 @@ import (
) )
const ( const (
errorNamespaceExists = Error("Namespace already exists") errNamespaceExists = Error("Namespace already exists")
errorNamespaceNotFound = Error("Namespace not found") errNamespaceNotFound = Error("Namespace not found")
errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
) )
// Namespace is the way Headscale implements the concept of users in Tailscale // Namespace is the way Headscale implements the concept of users in Tailscale
@ -30,51 +30,53 @@ type Namespace struct {
} }
// CreateNamespace creates a new Namespace. Returns error if could not be created // CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists // or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
n := Namespace{} namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errorNamespaceExists return nil, errNamespaceExists
} }
n.Name = name namespace.Name = name
if err := h.db.Create(&n).Error; err != nil { if err := h.db.Create(&namespace).Error; err != nil {
log.Error(). log.Error().
Str("func", "CreateNamespace"). Str("func", "CreateNamespace").
Err(err). Err(err).
Msg("Could not create row") Msg("Could not create row")
return nil, err return nil, err
} }
return &n, nil
return &namespace, nil
} }
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does // DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it. // not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error { func (h *Headscale) DestroyNamespace(name string) error {
n, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
return errorNamespaceNotFound return errNamespaceNotFound
} }
m, err := h.ListMachinesInNamespace(name) machines, err := h.ListMachinesInNamespace(name)
if err != nil { if err != nil {
return err return err
} }
if len(m) > 0 { if len(machines) > 0 {
return errorNamespaceNotEmptyOfNodes return errNamespaceNotEmptyOfNodes
} }
keys, err := h.ListPreAuthKeys(name) keys, err := h.ListPreAuthKeys(name)
if err != nil { if err != nil {
return err return err
} }
for _, p := range keys { for _, key := range keys {
err = h.DestroyPreAuthKey(&p) err = h.DestroyPreAuthKey(key)
if err != nil { if err != nil {
return err return err
} }
} }
if result := h.db.Unscoped().Delete(&n); result.Error != nil { if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error return result.Error
} }
@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does // RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name. // not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error { func (h *Headscale) RenameNamespace(oldName, newName string) error {
n, err := h.GetNamespace(oldName) oldNamespace, err := h.GetNamespace(oldName)
if err != nil { if err != nil {
return err return err
} }
_, err = h.GetNamespace(newName) _, err = h.GetNamespace(newName)
if err == nil { if err == nil {
return errorNamespaceExists return errNamespaceExists
} }
if !errors.Is(err, errorNamespaceNotFound) { if !errors.Is(err, errNamespaceNotFound) {
return err return err
} }
n.Name = newName oldNamespace.Name = newName
if result := h.db.Save(&n); result.Error != nil { if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error return result.Error
} }
err = h.RequestMapUpdates(n.ID) err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil { if err != nil {
return err return err
} }
@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return nil return nil
} }
// GetNamespace fetches a namespace by name // GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) { func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{} namespace := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.First(&namespace, "name = ?", name); errors.Is(
return nil, errorNamespaceNotFound result.Error,
} gorm.ErrRecordNotFound,
return &n, nil ) {
return nil, errNamespaceNotFound
} }
// ListNamespaces gets all the existing namespaces return &namespace, nil
}
// ListNamespaces gets all the existing namespaces.
func (h *Headscale) ListNamespaces() ([]Namespace, error) { func (h *Headscale) ListNamespaces() ([]Namespace, error) {
namespaces := []Namespace{} namespaces := []Namespace{}
if err := h.db.Find(&namespaces).Error; err != nil { if err := h.db.Find(&namespaces).Error; err != nil {
return nil, err return nil, err
} }
return namespaces, nil return namespaces, nil
} }
// ListMachinesInNamespace gets all the nodes in a given namespace // ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
n, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
machines := []Machine{} machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
return machines, nil return machines, nil
} }
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace // ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
namespace, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
machines := []Machine{} machines := []Machine{}
for _, sharedMachine := range sharedMachines { for _, sharedMachine := range sharedMachines {
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled machine, err := h.GetMachineByID(
sharedMachine.MachineID,
) // otherwise not everything comes filled
if err != nil { if err != nil {
return nil, err return nil, err
} }
machines = append(machines, *machine) machines = append(machines, *machine)
} }
return machines, nil return machines, nil
} }
// SetMachineNamespace assigns a Machine to a namespace // SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error { func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return err return err
} }
m.NamespaceID = n.ID machine.NamespaceID = namespace.ID
h.db.Save(&m) h.db.Save(&machine)
return nil return nil
} }
// RequestMapUpdates signals the KV worker to update the maps for this namespace // TODO(kradalby): Remove the need for this.
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error { func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{} namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil { if err := h.db.First(&namespace, namespaceID).Error; err != nil {
return err return err
} }
v, err := h.getValue("namespaces_pending_updates") namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" { if err != nil || namespacesPendingUpdates == "" {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
names := []string{} names := []string{}
err = json.Unmarshal([]byte(v), &names) err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil { if err != nil {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
Str("func", "RequestMapUpdates"). Str("func", "RequestMapUpdates").
Err(err). Err(err).
Msg("Could not marshal namespaces_pending_updates") Msg("Could not marshal namespaces_pending_updates")
return err return err
} }
return h.setValue("namespaces_pending_updates", string(data)) return h.setValue("namespaces_pending_updates", string(data))
} }
func (h *Headscale) checkForNamespacesPendingUpdates() { func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates") namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil { if err != nil {
return return
} }
if v == "" { if namespacesPendingUpdates == "" {
return return
} }
namespaces := []string{} namespaces := []string{}
err = json.Unmarshal([]byte(v), &namespaces) err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil { if err != nil {
return return
} }
@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace") Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace) h.setLastStateChangeToNow(namespace)
} }
newV, err := h.getValue("namespaces_pending_updates") newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil { if err != nil {
return return
} }
if v == newV { // only clear when no changes, so we notified everybody if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "") err = h.setValue("namespaces_pending_updates", "")
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "checkForNamespacesPendingUpdates"). Str("func", "checkForNamespacesPendingUpdates").
Err(err). Err(err).
Msg("Could not save to KV") Msg("Could not save to KV")
return return
} }
} }
} }
func (n *Namespace) toUser() *tailcfg.User { func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{ user := tailcfg.User{
ID: tailcfg.UserID(n.ID), ID: tailcfg.UserID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User {
Logins: []tailcfg.LoginID{}, Logins: []tailcfg.LoginID{},
Created: time.Time{}, Created: time.Time{},
} }
return &u
return &user
} }
func (n *Namespace) toLogin() *tailcfg.Login { func (n *Namespace) toLogin() *tailcfg.Login {
l := tailcfg.Login{ login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID), ID: tailcfg.LoginID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
ProfilePicURL: "", ProfilePicURL: "",
Domain: "headscale.net", Domain: "headscale.net",
} }
return &l
return &login
} }
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace) namespaceMap := make(map[string]Namespace)
namespaceMap[m.Namespace.Name] = m.Namespace namespaceMap[machine.Namespace.Name] = machine.Namespace
for _, p := range peers { for _, peer := range peers {
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
} }
profiles := []tailcfg.UserProfile{} profiles := []tailcfg.UserProfile{}
@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
DisplayName: namespace.Name, DisplayName: namespace.Name,
}) })
} }
return profiles return profiles
} }
func (n *Namespace) toProto() *v1.Namespace { func (n *Namespace) toProto() *v1.Namespace {
return &v1.Namespace{ return &v1.Namespace{
Id: strconv.FormatUint(uint64(n.ID), 10), Id: strconv.FormatUint(uint64(n.ID), Base10),
Name: n.Name, Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt), CreatedAt: timestamppb.New(n.CreatedAt),
} }

View file

@ -7,207 +7,232 @@ import (
) )
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) { func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(n.Name, check.Equals, "test") c.Assert(namespace.Name, check.Equals, "test")
ns, err := h.ListNamespaces() namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(ns), check.Equals, 1) c.Assert(len(namespaces), check.Equals, 1)
err = h.DestroyNamespace("test") err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetNamespace("test") _, err = app.GetNamespace("test")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) { func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
err := h.DestroyNamespace("test") err := app.DestroyNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotFound) c.Assert(err, check.Equals, errNamespaceNotFound)
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = h.DestroyNamespace("test") err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key) result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
// destroying a namespace also deletes all associated preauthkeys // destroying a namespace also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
n, err = h.CreateNamespace("test") namespace, err = app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err = h.CreatePreAuthKey(n.Name, false, false, nil) pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
err = h.DestroyNamespace("test") err = app.DestroyNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes) c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
} }
func (s *Suite) TestRenameNamespace(c *check.C) { func (s *Suite) TestRenameNamespace(c *check.C) {
n, err := h.CreateNamespace("test") namespaceTest, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(n.Name, check.Equals, "test") c.Assert(namespaceTest.Name, check.Equals, "test")
ns, err := h.ListNamespaces() namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(ns), check.Equals, 1) c.Assert(len(namespaces), check.Equals, 1)
err = h.RenameNamespace("test", "test_renamed") err = app.RenameNamespace("test", "test_renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetNamespace("test") _, err = app.GetNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotFound) c.Assert(err, check.Equals, errNamespaceNotFound)
_, err = h.GetNamespace("test_renamed") _, err = app.GetNamespace("test_renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = h.RenameNamespace("test_does_not_exit", "test") err = app.RenameNamespace("test_does_not_exit", "test")
c.Assert(err, check.Equals, errorNamespaceNotFound) c.Assert(err, check.Equals, errNamespaceNotFound)
n2, err := h.CreateNamespace("test2") namespaceTest2, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(n2.Name, check.Equals, "test2") c.Assert(namespaceTest2.Name, check.Equals, "test2")
err = h.RenameNamespace("test2", "test_renamed") err = app.RenameNamespace("test2", "test_renamed")
c.Assert(err, check.Equals, errorNamespaceExists) c.Assert(err, check.Equals, errNamespaceExists)
} }
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
n1, err := h.CreateNamespace("shared1") namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2") namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3") namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) preAuthKeyShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil) preAuthKeyShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil) preAuthKeyShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) preAuthKey2Shared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1") _, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m1 := &Machine{ machineInShared1 := &Machine{
ID: 1, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.1", IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID), AuthKeyID: uint(preAuthKeyShared1.ID),
} }
h.db.Save(m1) app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name) _, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m2 := &Machine{ machineInShared2 := &Machine{
ID: 2, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *n2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.2", IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID), AuthKeyID: uint(preAuthKeyShared2.ID),
} }
h.db.Save(m2) app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name) _, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m3 := &Machine{ machineInShared3 := &Machine{
ID: 3, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *n3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.3", IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID), AuthKeyID: uint(preAuthKeyShared3.ID),
} }
h.db.Save(m3) app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name) _, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ machine2InShared1 := &Machine{
ID: 4, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *n1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.4", IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID), AuthKeyID: uint(preAuthKey2Shared1.ID),
} }
h.db.Save(m4) app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m1peers, err := h.getPeers(m1) peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
userProfiles := getMapResponseUserProfiles(*m1, m1peers) userProfiles := getMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
log.Trace().Msgf("userProfiles %#v", userProfiles) log.Trace().Msgf("userProfiles %#v", userProfiles)
c.Assert(len(userProfiles), check.Equals, 2) c.Assert(len(userProfiles), check.Equals, 2)
found := false found := false
for _, up := range userProfiles { for _, userProfiles := range userProfiles {
if up.DisplayName == n1.Name { if userProfiles.DisplayName == namespaceShared1.Name {
found = true found = true
break break
} }
} }
c.Assert(found, check.Equals, true) c.Assert(found, check.Equals, true)
found = false found = false
for _, up := range userProfiles { for _, userProfile := range userProfiles {
if up.DisplayName == n2.Name { if userProfile.DisplayName == namespaceShared2.Name {
found = true found = true
break break
} }
} }

138
oidc.go
View file

@ -17,6 +17,12 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
const (
oidcStateCacheExpiration = time.Minute * 5
oidcStateCacheCleanupInterval = time.Minute * 10
randomByteSize = 16
)
type IDTokenClaims struct { type IDTokenClaims struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"` Groups []string `json:"groups,omitempty"`
@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error {
if err != nil { if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err return err
} }
@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error {
ClientID: h.cfg.OIDC.ClientID, ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret, ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(), Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
} }
} }
// init the state cache if it hasn't been already // init the state cache if it hasn't been already
if h.oidcStateCache == nil { if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10) h.oidcStateCache = cache.New(
oidcStateCacheExpiration,
oidcStateCacheCleanupInterval,
)
} }
return nil return nil
@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(c *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := c.Param("mkey") mKeyStr := ctx.Param("mkey")
if mKeyStr == "" { if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
b := make([]byte, 16) randomBlob := make([]byte, randomByteSize)
_, err := rand.Read(b) if _, err := rand.Read(randomBlob); err != nil {
if err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return return
} }
stateStr := hex.EncodeToString(b)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5) h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
authUrl := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
c.Redirect(http.StatusFound, authUrl) ctx.Redirect(http.StatusFound, authURL)
} }
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace // Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(c *gin.Context) { func (h *Headscale) OIDCCallback(ctx *gin.Context) {
code := c.Query("code") code := ctx.Query("code")
state := c.Query("state") state := ctx.Query("state")
if code == "" || state == "" { if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token") ctx.String(http.StatusBadRequest, "Could not exchange code for token")
return return
} }
@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token") ctx.String(http.StatusBadRequest, "Could not extract ID Token")
return return
} }
@ -113,7 +130,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return return
} }
@ -127,7 +145,11 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims // Extract custom claims
var claims IDTokenClaims var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil { if err = idToken.Claims(&claims); err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err)) ctx.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
)
return return
} }
@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
mKeyIf, mKeyFound := h.oidcStateCache.Get(state) mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound { if !mKeyFound {
log.Error().Msg("requested machine state key expired before authorisation completed") log.Error().
c.String(http.StatusBadRequest, "state has expired") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired")
return return
} }
mKeyStr, mKeyOK := mKeyIf.(string) mKeyStr, mKeyOK := mKeyIf.(string)
if !mKeyOK { if !mKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache") ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
)
return return
} }
// retrieve machine information // retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr) machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil { if err != nil {
log.Error().Msg("machine key not found in database") log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database") ctx.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new // register the machine if it's new
if !m.Registered { if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
ns, err = h.CreateNamespace(nsName) namespace, err = h.CreateNamespace(namespaceName)
if err != nil { if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email) log.Error().
c.String(http.StatusInternalServerError, "could not create new namespace") Msgf("could not create new namespace '%s'", claims.Email)
ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
)
return return
} }
} }
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool") ctx.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
)
return return
} }
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = ns.ID machine.NamespaceID = namespace.ID
m.Registered = true machine.Registered = true
m.RegisterMethod = "oidc" machine.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
} }
h.updateMachineExpiry(m) h.updateMachineExpiry(machine)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
</html> </html>
`, claims.Email))) `, claims.Email)))
} }
log.Error(). log.Error().
Str("email", claims.Email). Str("email", claims.Email).
Str("username", claims.Username). Str("username", claims.Username).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace") Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") ctx.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)
} }
// getNamespaceFromEmail passes the users email through a list of "matchers" // getNamespaceFromEmail passes the users email through a list of "matchers"

View file

@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
}, },
} }
//nolint //nolint
for _, tt := range tests { for _, test := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
h := &Headscale{ app := &Headscale{
cfg: tt.fields.cfg, cfg: test.fields.cfg,
db: tt.fields.db, db: test.fields.db,
dbString: tt.fields.dbString, dbString: test.fields.dbString,
dbType: tt.fields.dbType, dbType: test.fields.dbType,
dbDebug: tt.fields.dbDebug, dbDebug: test.fields.dbDebug,
publicKey: tt.fields.publicKey, publicKey: test.fields.publicKey,
privateKey: tt.fields.privateKey, privateKey: test.fields.privateKey,
aclPolicy: tt.fields.aclPolicy, aclPolicy: test.fields.aclPolicy,
aclRules: tt.fields.aclRules, aclRules: test.fields.aclRules,
lastStateChange: tt.fields.lastStateChange, lastStateChange: test.fields.lastStateChange,
oidcProvider: tt.fields.oidcProvider, oidcProvider: test.fields.oidcProvider,
oauth2Config: tt.fields.oauth2Config, oauth2Config: test.fields.oauth2Config,
oidcStateCache: tt.fields.oidcStateCache, oidcStateCache: test.fields.oidcStateCache,
} }
got, got1 := h.getNamespaceFromEmail(tt.args.email) got, got1 := app.getNamespaceFromEmail(test.args.email)
if got != tt.want { if got != test.want {
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want) t.Errorf(
"Headscale.getNamespaceFromEmail() got = %v, want %v",
got,
test.want,
)
} }
if got1 != tt.want1 { if got1 != test.want1 {
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1) t.Errorf(
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
got1,
test.want1,
)
} }
}) })
} }

263
poll.go
View file

@ -15,6 +15,11 @@ import (
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
const (
keepAliveInterval = 60 * time.Second
updateCheckInterval = 10 * time.Second
)
// PollNetMapHandler takes care of /machine/:id/map // PollNetMapHandler takes care of /machine/:id/map
// //
// This is the busiest endpoint, as it keeps the HTTP long poll that updates // This is the busiest endpoint, as it keeps the HTTP long poll that updates
@ -24,20 +29,21 @@ import (
// only after their first request (marked with the ReadOnly field). // only after their first request (marked with the ReadOnly field).
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) { func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called") Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := c.Param("id") mKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot parse client key") Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
req := tailcfg.MapRequest{} req := tailcfg.MapRequest{}
@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
m, err := h.GetMachineByMachineKey(mKey.HexString()) machine, err := h.GetMachineByMachineKey(mKey.HexString())
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "") ctx.String(http.StatusUnauthorized, "")
return return
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString()) Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Found machine in database") Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo) hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname machine.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo) machine.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC() now := time.Now().UTC()
// From Tailscale client: // From Tailscale client:
@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// before their first real endpoint update. // before their first real endpoint update.
if !req.ReadOnly { if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints) endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints) machine.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now machine.LastSeen = &now
} }
h.db.Save(&m) h.db.Save(&machine)
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(mKey, req, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(") ctx.String(http.StatusInternalServerError, ":(")
return return
} }
@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly). Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers). Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream). Bool("stream", req.Stream).
@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.ReadOnly { if req.ReadOnly {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
return return
} }
// There has been an update to _any_ of the nodes that the other nodes would // There has been an update to _any_ of the nodes that the other nodes would
// need to know about // need to know about
h.setLastStateChangeToNow(m.Namespace.Name) h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating // The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll // peers via longpoll
@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Only create update channel if it has not been created // Only create update channel if it has not been created
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
updateChan := make(chan struct{}) updateChan := make(chan struct{})
@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.OmitPeers && !req.Stream { if req.OmitPeers && !req.Stream {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc() updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
Inc()
go func() { updateChan <- struct{}{} }() go func() { updateChan <- struct{}{} }()
return return
} else if req.OmitPeers && req.Stream { } else if req.OmitPeers && req.Stream {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is ready to access the tailnet") Msg("Client is ready to access the tailnet")
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending initial map") Msg("Sending initial map")
go func() { pollDataChan <- data }() go func() { pollDataChan <- data }()
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Notifying peers") Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc() updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
Inc()
go func() { updateChan <- struct{}{} }() go func() { updateChan <- struct{}{} }()
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive) h.PollNetMapStream(
ctx,
machine,
req,
mKey,
pollDataChan,
keepAliveChan,
updateChan,
cancelKeepAlive,
)
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// stream logic, ensuring we communicate updates and data // stream logic, ensuring we communicate updates and data
// to the connected clients. // to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
c *gin.Context, ctx *gin.Context,
m *Machine, machine *Machine,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
mKey wgkey.Key, machineKey wgkey.Key,
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
cancelKeepAlive chan struct{}, cancelKeepAlive chan struct{},
) { ) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) go h.scheduledPollWorker(
cancelKeepAlive,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
c.Stream(func(w io.Writer) bool { ctx.Stream(func(writer io.Writer) bool {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Waiting for data to stream...") Msg("Waiting for data to stream...")
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select { select {
case data := <-pollDataChan: case data := <-pollDataChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
_, err := w.Write(data) _, err := writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
return false return false
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Data from pollData channel written successfully") Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix())) lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
m.LastSuccessfulUpdate = &now Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData") Msg("Machine entry in database updated successfully after sending pollData")
return true return true
case data := <-keepAliveChan: case data := <-keepAliveChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending keep alive message") Msg("Sending keep alive message")
_, err := w.Write(data) _, err := writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
return false return false
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive") Msg("Machine updated successfully after sending keep alive")
return true return true
case <-updateChan: case <-updateChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Msg("Received a request for update") Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc() updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
if h.isOutdated(m) { Inc()
if h.isOutdated(machine) {
log.Debug(). log.Debug().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name) Msgf("There has been updates since the last successful update to %s", machine.Name)
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
} }
_, err = w.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not write the map response") Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc() updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
Inc()
return false return false
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Msg("Updated Map has been sent") Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc() updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
Inc()
// Keep track of the last successful update, // Keep track of the last successful update,
// we sometimes end in a state were the update // we sometimes end in a state were the update
@ -366,62 +405,64 @@ func (h *Headscale) PollNetMapStream(
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix())) lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
m.LastSuccessfulUpdate = &now Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
} else { } else {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("%s is up to date", m.Name) Msgf("%s is up to date", machine.Name)
} }
return true return true
case <-c.Request.Context().Done(): case <-ctx.Request.Context().Done():
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("The client has closed the connection") Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions // TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err := h.UpdateMachine(m) err := h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Cancelling keepAlive channel") Msg("Cancelling keepAlive channel")
cancelKeepAlive <- struct{}{} cancelKeepAlive <- struct{}{}
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing update channel") Msg("Closing update channel")
// h.closeUpdateChannel(m) // h.closeUpdateChannel(m)
@ -429,14 +470,14 @@ func (h *Headscale) PollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing pollData channel") Msg("Closing pollData channel")
close(pollDataChan) close(pollDataChan)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing keepAliveChan channel") Msg("Closing keepAliveChan channel")
close(keepAliveChan) close(keepAliveChan)
@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{}, cancelChan <-chan struct{},
updateChan chan<- struct{}, updateChan chan<- struct{},
keepAliveChan chan<- []byte, keepAliveChan chan<- []byte,
mKey wgkey.Key, machineKey wgkey.Key,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
m *Machine, machine *Machine,
) { ) {
keepAliveTicker := time.NewTicker(60 * time.Second) keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(10 * time.Second) updateCheckerTicker := time.NewTicker(updateCheckInterval)
for { for {
select { select {
@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker(
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m) data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "keepAlive"). Str("func", "keepAlive").
Err(err). Err(err).
Msg("Error generating the keep alive msg") Msg("Error generating the keep alive msg")
return return
} }
log.Debug(). log.Debug().
Str("func", "keepAlive"). Str("func", "keepAlive").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending keepalive") Msg("Sending keepalive")
keepAliveChan <- data keepAliveChan <- data
case <-updateCheckerTicker.C: case <-updateCheckerTicker.C:
log.Debug(). log.Debug().
Str("func", "scheduledPollWorker"). Str("func", "scheduledPollWorker").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending update request") Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc() updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
Inc()
updateChan <- struct{}{} updateChan <- struct{}{}
} }
} }

View file

@ -7,19 +7,19 @@ import (
"strconv" "strconv"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
) )
const ( const (
errorAuthKeyNotFound = Error("AuthKey not found") errPreAuthKeyNotFound = Error("AuthKey not found")
errorAuthKeyExpired = Error("AuthKey expired") errPreAuthKeyExpired = Error("AuthKey expired")
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
errNamespaceMismatch = Error("namespace mismatch")
) )
// PreAuthKey describes a pre-authorization key usable in a particular namespace // PreAuthKey describes a pre-authorization key usable in a particular namespace.
type PreAuthKey struct { type PreAuthKey struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
Key string Key string
@ -33,14 +33,14 @@ type PreAuthKey struct {
Expiration *time.Time Expiration *time.Time
} }
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it // CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
func (h *Headscale) CreatePreAuthKey( func (h *Headscale) CreatePreAuthKey(
namespaceName string, namespaceName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
) (*PreAuthKey, error) { ) (*PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err return nil, err
} }
k := PreAuthKey{ key := PreAuthKey{
Key: kstr, Key: kstr,
NamespaceID: n.ID, NamespaceID: namespace.ID,
Namespace: *n, Namespace: *namespace,
Reusable: reusable, Reusable: reusable,
Ephemeral: ephemeral, Ephemeral: ephemeral,
CreatedAt: &now, CreatedAt: &now,
Expiration: expiration, Expiration: expiration,
} }
h.db.Save(&k) h.db.Save(&key)
return &k, nil return &key, nil
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace // ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) { func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
keys := []PreAuthKey{} keys := []PreAuthKey{}
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
return keys, nil return keys, nil
} }
// GetPreAuthKey returns a PreAuthKey for a given key // GetPreAuthKey returns a PreAuthKey for a given key.
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) { func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
pak, err := h.checkKeyValidity(key) pak, err := h.checkKeyValidity(key)
if err != nil { if err != nil {
@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
} }
if pak.Namespace.Name != namespace { if pak.Namespace.Name != namespace {
return nil, errors.New("Namespace mismatch") return nil, errNamespaceMismatch
} }
return pak, nil return pak, nil
@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. // does not exist.
func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error { func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
if result := h.db.Unscoped().Delete(&pak); result.Error != nil { if result := h.db.Unscoped().Delete(pak); result.Error != nil {
return result.Error return result.Error
} }
return nil return nil
} }
// MarkExpirePreAuthKey marks a PreAuthKey as expired // MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
return nil return nil
} }
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used // If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{} pak := PreAuthKey{}
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
return nil, errorAuthKeyNotFound result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errPreAuthKeyNotFound
} }
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, errorAuthKeyExpired return nil, errPreAuthKeyExpired
} }
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) {
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
return "", err return "", err
} }
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
func (key *PreAuthKey) toProto() *v1.PreAuthKey { func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{ protoKey := v1.PreAuthKey{
Namespace: key.Namespace.Name, Namespace: key.Namespace.Name,
Id: strconv.FormatUint(key.ID, 10), Id: strconv.FormatUint(key.ID, Base10),
Key: key.Key, Key: key.Key,
Ephemeral: key.Ephemeral, Ephemeral: key.Ephemeral,
Reusable: key.Reusable, Reusable: key.Reusable,

View file

@ -7,189 +7,189 @@ import (
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
_, err := h.CreatePreAuthKey("bogus", true, false, nil) _, err := app.CreatePreAuthKey("bogus", true, false, nil)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
k, err := h.CreatePreAuthKey(n.Name, true, false, nil) key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Did we get a valid key? // Did we get a valid key?
c.Assert(k.Key, check.NotNil) c.Assert(key.Key, check.NotNil)
c.Assert(len(k.Key), check.Equals, 48) c.Assert(len(key.Key), check.Equals, 48)
// Make sure the Namespace association is populated // Make sure the Namespace association is populated
c.Assert(k.Namespace.Name, check.Equals, n.Name) c.Assert(key.Namespace.Name, check.Equals, namespace.Name)
_, err = h.ListPreAuthKeys("bogus") _, err = app.ListPreAuthKeys("bogus")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
keys, err := h.ListPreAuthKeys(n.Name) keys, err := app.ListPreAuthKeys(namespace.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
// Make sure the Namespace association is populated // Make sure the Namespace association is populated
c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name) c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name)
} }
func (*Suite) TestExpiredPreAuthKey(c *check.C) { func (*Suite) TestExpiredPreAuthKey(c *check.C) {
n, err := h.CreateNamespace("test2") namespace, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now() now := time.Now()
pak, err := h.CreatePreAuthKey(n.Name, true, false, &now) pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errorAuthKeyExpired) c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(p, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
p, err := h.checkKeyValidity("potatoKey") key, err := app.checkKeyValidity("potatoKey")
c.Assert(err, check.Equals, errorAuthKeyNotFound) c.Assert(err, check.Equals, errPreAuthKeyNotFound)
c.Assert(p, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestValidateKeyOk(c *check.C) { func (*Suite) TestValidateKeyOk(c *check.C) {
n, err := h.CreateNamespace("test3") namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestAlreadyUsedKey(c *check.C) { func (*Suite) TestAlreadyUsedKey(c *check.C) {
n, err := h.CreateNamespace("test4") namespace, err := app.CreateNamespace("test4")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
c.Assert(p, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestReusableBeingUsedKey(c *check.C) { func (*Suite) TestReusableBeingUsedKey(c *check.C) {
n, err := h.CreateNamespace("test5") namespace, err := app.CreateNamespace("test5")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 1, ID: 1,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
n, err := h.CreateNamespace("test6") namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestEphemeralKey(c *check.C) { func (*Suite) TestEphemeralKey(c *check.C) {
n, err := h.CreateNamespace("test7") namespace, err := app.CreateNamespace("test7")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, true, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now() now := time.Now()
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
_, err = h.checkKeyValidity(pak.Key) _, err = app.checkKeyValidity(pak.Key)
// Ephemeral keys are by definition reusable // Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test7", "testest") _, err = app.GetMachine("test7", "testest")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
h.expireEphemeralNodesWorker() app.expireEphemeralNodesWorker()
// The machine record should have been deleted // The machine record should have been deleted
_, err = h.GetMachine("test7", "testest") _, err = app.GetMachine("test7", "testest")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (*Suite) TestExpirePreauthKey(c *check.C) { func (*Suite) TestExpirePreauthKey(c *check.C) {
n, err := h.CreateNamespace("test3") namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil) c.Assert(pak.Expiration, check.IsNil)
err = h.ExpirePreAuthKey(pak) err = app.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil) c.Assert(pak.Expiration, check.NotNil)
p, err := h.checkKeyValidity(pak.Key) key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errorAuthKeyExpired) c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(p, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
n, err := h.CreateNamespace("test6") namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak.Used = true pak.Used = true
h.db.Save(&pak) app.db.Save(&pak)
_, err = h.checkKeyValidity(pak.Key) _, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
} }

View file

@ -2,38 +2,48 @@ package headscale
import ( import (
"encoding/json" "encoding/json"
"fmt"
"gorm.io/datatypes" "gorm.io/datatypes"
"inet.af/netaddr" "inet.af/netaddr"
) )
const (
errRouteIsNotAvailable = Error("route is not available")
)
// Deprecated: use machine function instead // Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { func (h *Headscale) GetAdvertisedNodeRoutes(
m, err := h.GetMachine(namespace, nodeName) namespace string,
nodeName string,
) (*[]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostInfo, err := m.GetHostInfo() hostInfo, err := machine.GetHostInfo()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &hostInfo.RoutableIPs, nil return &hostInfo.RoutableIPs, nil
} }
// Deprecated: use machine function instead // Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) { func (h *Headscale) GetEnabledNodeRoutes(
m, err := h.GetMachine(namespace, nodeName) namespace string,
nodeName string,
) ([]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
data, err := m.EnabledRoutes.MarshalJSON() data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
} }
// Deprecated: use machine function instead // Deprecated: use machine function instead
// IsNodeRouteEnabled checks if a certain route has been enabled // IsNodeRouteEnabled checks if a certain route has been enabled.
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool { func (h *Headscale) IsNodeRouteEnabled(
namespace string,
nodeName string,
routeStr string,
) bool {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil { if err != nil {
return false return false
@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
return true return true
} }
} }
return false return false
} }
// Deprecated: use EnableRoute in machine.go // Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by // EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { func (h *Headscale) EnableNodeRoute(
m, err := h.GetMachine(namespace, nodeName) namespace string,
nodeName string,
routeStr string,
) error {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return err return err
} }
@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
} }
if !available { if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr) return errRouteIsNotAvailable
} }
routes, err := json.Marshal(enabledRoutes) routes, err := json.Marshal(enabledRoutes)
@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
return err return err
} }
m.EnabledRoutes = datatypes.JSON(routes) machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m) h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID) err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -10,57 +10,60 @@ import (
) )
func (s *Suite) TestGetRoutes(c *check.C) { func (s *Suite) TestGetRoutes(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_get_route_machine") _, err = app.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24") route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{ hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route}, RoutableIPs: []netaddr.IPPrefix{route},
} }
hostinfo, err := json.Marshal(hi) hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "test_get_route_machine", Name: "test_get_route_machine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo), HostInfo: datatypes.JSON(hostinfo),
} }
h.db.Save(&m) app.db.Save(&machine)
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine") advertisedRoutes, err := app.GetAdvertisedNodeRoutes(
"test",
"test_get_route_machine",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1) c.Assert(len(*advertisedRoutes), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24") err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24") err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestGetEnableRoutes(c *check.C) { func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_enable_route_machine") _, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix( route, err := netaddr.ParseIPPrefix(
@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
) )
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{ hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2}, RoutableIPs: []netaddr.IPPrefix{route, route2},
} }
hostinfo, err := json.Marshal(hi) hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "test_enable_route_machine", Name: "test_enable_route_machine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo), HostInfo: datatypes.JSON(hostinfo),
} }
h.db.Save(&m) app.db.Save(&machine)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine") availableRoutes, err := app.GetAdvertisedNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2) c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") noEnabledRoutes, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0) c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24") err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through // Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25") err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
} }

View file

@ -2,11 +2,13 @@ package headscale
import "gorm.io/gorm" import "gorm.io/gorm"
const errorSameNamespace = Error("Destination namespace same as origin") const (
const errorMachineAlreadyShared = Error("Node already shared to this namespace") errSameNamespace = Error("Destination namespace same as origin")
const errorMachineNotShared = Error("Machine not shared to this namespace") errMachineAlreadyShared = Error("Node already shared to this namespace")
errMachineNotShared = Error("Machine not shared to this namespace")
)
// SharedMachine is a join table to support sharing nodes between namespaces // SharedMachine is a join table to support sharing nodes between namespaces.
type SharedMachine struct { type SharedMachine struct {
gorm.Model gorm.Model
MachineID uint64 MachineID uint64
@ -15,49 +17,57 @@ type SharedMachine struct {
Namespace Namespace Namespace Namespace
} }
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace // AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) AddSharedMachineToNamespace(
if m.NamespaceID == ns.ID { machine *Machine,
return errorSameNamespace namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
return errSameNamespace
} }
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil { if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err return err
} }
if len(sharedMachines) > 0 { if len(sharedMachines) > 0 {
return errorMachineAlreadyShared return errMachineAlreadyShared
} }
sharedMachine := SharedMachine{ sharedMachine := SharedMachine{
MachineID: m.ID, MachineID: machine.ID,
Machine: *m, Machine: *machine,
NamespaceID: ns.ID, NamespaceID: namespace.ID,
Namespace: *ns, Namespace: *namespace,
} }
h.db.Save(&sharedMachine) h.db.Save(&sharedMachine)
return nil return nil
} }
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace // RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) RemoveSharedMachineFromNamespace(
if m.NamespaceID == ns.ID { machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace // Can't unshare from primary namespace
return errorMachineNotShared return errMachineNotShared
} }
sharedMachine := SharedMachine{} sharedMachine := SharedMachine{}
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine) result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
Unscoped().
Delete(&sharedMachine)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return errorMachineNotShared return errMachineNotShared
} }
err := h.RequestMapUpdates(ns.ID) err := h.RequestMapUpdates(namespace.ID)
if err != nil { if err != nil {
return err return err
} }
@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return nil return nil
} }
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces // RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{} sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error return result.Error
} }

View file

@ -4,45 +4,48 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) { func CreateNodeNamespace(
n1, err := h.CreateNamespace(namespace) c *check.C,
namespaceName, node, key, ip string,
) (*Namespace, *Machine) {
namespace, err := app.CreateNamespace(namespaceName)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, node) _, err = app.GetMachine(namespace.Name, node)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m1 := &Machine{ machine := &Machine{
ID: 0, ID: 0,
MachineKey: key, MachineKey: key,
NodeKey: key, NodeKey: key,
DiscoKey: key, DiscoKey: key,
Name: node, Name: node,
NamespaceID: n1.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: IP, IPAddress: ip,
AuthKeyID: uint(pak1.ID), AuthKeyID: uint(pak1.ID),
} }
h.db.Save(m1) app.db.Save(machine)
_, err = h.GetMachine(n1.Name, m1.Name) _, err = app.GetMachine(namespace.Name, machine.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
return n1, m1 return namespace, machine
} }
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) { func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_get_shared_nodes_2", "test_get_shared_nodes_2",
@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
"100.64.0.2", "100.64.0.2",
) )
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1) peersOfMachine1AfterShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 1) c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1)
c.Assert(p1sAfter[0].ID, check.Equals, m2.ID) c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID)
} }
func (s *Suite) TestSameNamespace(c *check.C) { func (s *Suite) TestSameNamespace(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) {
"100.64.0.1", "100.64.0.1",
) )
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m1, n1) err = app.AddSharedMachineToNamespace(machine1, namespace1)
c.Assert(err, check.Equals, errorSameNamespace) c.Assert(err, check.Equals, errSameNamespace)
} }
func (s *Suite) TestUnshare(c *check.C) { func (s *Suite) TestUnshare(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_unshare_1", "test_unshare_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_unshare_2", "test_unshare_2",
@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) {
"100.64.0.2", "100.64.0.2",
) )
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1s, err = h.getShared(m1) peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1)
err = h.RemoveSharedMachineFromNamespace(m2, n1) err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1s, err = h.getShared(m1) peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.RemoveSharedMachineFromNamespace(m2, n1) err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
c.Assert(err, check.Equals, errorMachineNotShared) c.Assert(err, check.Equals, errMachineNotShared)
err = h.RemoveSharedMachineFromNamespace(m1, n1) err = app.RemoveSharedMachineFromNamespace(machine1, namespace1)
c.Assert(err, check.Equals, errorMachineNotShared) c.Assert(err, check.Equals, errMachineNotShared)
} }
func (s *Suite) TestAlreadyShared(c *check.C) { func (s *Suite) TestAlreadyShared(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_get_shared_nodes_2", "test_get_shared_nodes_2",
@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) {
"100.64.0.2", "100.64.0.2",
) )
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.Equals, errorMachineAlreadyShared) c.Assert(err, check.Equals, errMachineAlreadyShared)
} }
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) { func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_get_shared_nodes_2", "test_get_shared_nodes_2",
@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
"100.64.0.2", "100.64.0.2",
) )
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0) c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1) peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 1) c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1)
c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2") c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2")
} }
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_get_shared_nodes_2", "test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2", "100.64.0.2",
) )
_, m3 := CreateNodeNamespace( _, machine3 := CreateNodeNamespace(
c, c,
"shared3", "shared3",
"test_get_shared_nodes_3", "test_get_shared_nodes_3",
@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
"100.64.0.3", "100.64.0.3",
) )
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil) pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ machine4 := &Machine{
ID: 4, ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID, NamespaceID: namespace1.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.4", IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4.ID), AuthKeyID: uint(pak4.ID),
} }
h.db.Save(m4) app.db.Save(machine4)
_, err = h.GetMachine(n1.Name, m4.Name) _, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // node1 can see node4 c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4
c.Assert(p1s[0].Name, check.Equals, m4.Name) c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1) peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
node1shared, err := h.getShared(m1)
c.Assert(err, check.IsNil)
c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
pAlone, err := h.getPeers(m3)
c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
pSharedTo, err := h.getPeers(m2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert( c.Assert(
len(pSharedTo), len(peersOfMachine1AfterShare),
check.Equals,
2,
) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
sharedOfMachine1, err := app.getShared(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
peersOfMachine3, err := app.getPeers(machine3)
c.Assert(err, check.IsNil)
c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone
peersOfMachine2, err := app.getPeers(machine2)
c.Assert(err, check.IsNil)
c.Assert(
len(peersOfMachine2),
check.Equals, check.Equals,
2, 2,
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1 ) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
c.Assert(pSharedTo[0].Name, check.Equals, m1.Name) c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name)
c.Assert(pSharedTo[1].Name, check.Equals, m4.Name) c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name)
} }
func (s *Suite) TestDeleteSharedMachine(c *check.C) { func (s *Suite) TestDeleteSharedMachine(c *check.C) {
n1, m1 := CreateNodeNamespace( namespace1, machine1 := CreateNodeNamespace(
c, c,
"shared1", "shared1",
"test_get_shared_nodes_1", "test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1", "100.64.0.1",
) )
_, m2 := CreateNodeNamespace( _, machine2 := CreateNodeNamespace(
c, c,
"shared2", "shared2",
"test_get_shared_nodes_2", "test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2", "100.64.0.2",
) )
_, m3 := CreateNodeNamespace( _, machine3 := CreateNodeNamespace(
c, c,
"shared3", "shared3",
"test_get_shared_nodes_3", "test_get_shared_nodes_3",
@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
"100.64.0.3", "100.64.0.3",
) )
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ machine4 := &Machine{
ID: 4, ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID, NamespaceID: namespace1.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
IPAddress: "100.64.0.4", IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID), AuthKeyID: uint(pak4n1.ID),
} }
h.db.Save(m4) app.db.Save(machine4)
_, err = h.GetMachine(n1.Name, m4.Name) _, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1) peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4 c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4
c.Assert(p1s[0].Name, check.Equals, m4.Name) c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
err = h.AddSharedMachineToNamespace(m2, n1) err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1) peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4 c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name) c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name) c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
node1shared, err := h.getShared(m1) sharedOfMachine1, err := app.getShared(machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4 c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4
c.Assert(node1shared[0].Name, check.Equals, m2.Name) c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
pAlone, err := h.getPeers(m3) peersOfMachine3, err := app.getPeers(machine3)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone
sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name) sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace(
namespace1.Name,
)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(sharedMachines), check.Equals, 1) c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1)
err = h.DeleteMachine(m2) err = app.DeleteMachine(machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name) sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(sharedMachines), check.Equals, 0) c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0)
} }

View file

@ -6,16 +6,15 @@ import (
"net/http" "net/http"
"text/template" "text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
) )
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte var apiV1JSON []byte
func SwaggerUI(c *gin.Context) { func SwaggerUI(ctx *gin.Context) {
t := template.Must(template.New("swagger").Parse(` swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html> <html>
<head> <head>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"> <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) {
</html>`)) </html>`))
var payload bytes.Buffer var payload bytes.Buffer
if err := t.Execute(&payload, struct{}{}); err != nil { if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Could not render Swagger") Msg("Could not render Swagger")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger")) ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Swagger"),
)
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
} }
func SwaggerAPIv1(c *gin.Context) { func SwaggerAPIv1(ctx *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
} }

View file

@ -20,31 +20,48 @@ import (
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
const (
errCannotDecryptReponse = Error("cannot decrypt response")
errResponseMissingNonce = Error("response missing nonce")
errCouldNotAllocateIP = Error("could not find any suitable IP")
)
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string type Error string
func (e Error) Error() string { return string(e) } func (e Error) Error() string { return string(e) }
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error { func decode(
msg []byte,
v interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
return decodeMsg(msg, v, pubKey, privKey) return decodeMsg(msg, v, pubKey, privKey)
} }
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error { func decodeMsg(
msg []byte,
output interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
decrypted, err := decryptMsg(msg, pubKey, privKey) decrypted, err := decryptMsg(msg, pubKey, privKey)
if err != nil { if err != nil {
return err return err
} }
// fmt.Println(string(decrypted)) // fmt.Println(string(decrypted))
if err := json.Unmarshal(decrypted, v); err != nil { if err := json.Unmarshal(decrypted, output); err != nil {
return fmt.Errorf("response: %v", err) return err
} }
return nil return nil
} }
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
var nonce [24]byte var nonce [24]byte
if len(msg) < len(nonce)+1 { if len(msg) < len(nonce)+1 {
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg)) return nil, errResponseMissingNonce
} }
copy(nonce[:], msg) copy(nonce[:], msg)
msg = msg[len(nonce):] msg = msg[len(nonce):]
@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri) decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
if !ok { if !ok {
return nil, fmt.Errorf("cannot decrypt response") return nil, errCannotDecryptReponse
} }
return decrypted, nil return decrypted, nil
} }
@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey) return encodeMsg(b, pubKey, privKey)
} }
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { func encodeMsg(
payload []byte,
pubKey *wgkey.Key,
privKey *wgkey.Private,
) ([]byte, error) {
var nonce [24]byte var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err) panic(err)
} }
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri) msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
return msg, nil return msg, nil
} }
@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
for { for {
if !ipPrefix.Contains(ip) { if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix) return nil, errCouldNotAllocateIP
} }
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which // Some OS (including Linux) does not like when IPs ends with 0 or 255, which
@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipRaw := ip.As4() ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 { if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next() ip = ip.Next()
continue continue
} }
if ip.IsZero() && if ip.IsZero() &&
ip.IsLoopback() { ip.IsLoopback() {
ip = ip.Next() ip = ip.Next()
continue continue
} }
@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
if addr != "" { if addr != "" {
ip, err := netaddr.ParseIP(addr) ip, err := netaddr.ParseIP(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse ip from database, %w", err) return nil, fmt.Errorf("failed to parse ip from database: %w", err)
} }
ips[index] = ip ips[index] = ip
@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
} }
func tailMapResponseToString(resp tailcfg.MapResponse) string { func tailMapResponseToString(resp tailcfg.MapResponse) string {
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers)) return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
} }
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer var d net.Dialer
return d.DialContext(ctx, "unix", addr) return d.DialContext(ctx, "unix", addr)
} }
@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
return result return result
} }
func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes)) result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes { for index, prefixStr := range prefixes {
@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil return result, nil
} }
func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
for _, p := range prefixes { for _, p := range prefixes {
if prefix == p { if prefix == p {
return true return true

View file

@ -6,7 +6,7 @@ import (
) )
func (s *Suite) TestGetAvailableIp(c *check.C) { func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP() ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
} }
func (s *Suite) TestGetUsedIps(c *check.C) { func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP() ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
n, err := h.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine") _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddress: ip.String(), IPAddress: ip.String(),
} }
h.db.Save(&m) app.db.Save(&machine)
ips, err := h.getUsedIPs() ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(ips[0], check.Equals, expected) c.Assert(ips[0], check.Equals, expected)
m1, err := h.GetMachineByID(0) machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, expected.String()) c.Assert(machine1.IPAddress, check.Equals, expected.String())
} }
func (s *Suite) TestGetMultiIp(c *check.C) { func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi") namespace, err := app.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
for i := 1; i <= 350; i++ { for index := 1; index <= 350; index++ {
ip, err := h.getAvailableIP() ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine") _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m := Machine{ machine := Machine{
ID: uint64(i), ID: uint64(index),
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddress: ip.String(), IPAddress: ip.String(),
} }
h.db.Save(&m) app.db.Save(&machine)
} }
ips, err := h.getUsedIPs() ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47")) c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
// Check that we can read back the IPs // Check that we can read back the IPs
m1, err := h.GetMachineByID(1) machine1, err := app.GetMachineByID(1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String()) c.Assert(
machine1.IPAddress,
check.Equals,
netaddr.MustParseIP("10.27.0.1").String(),
)
m50, err := h.GetMachineByID(50) machine50, err := app.GetMachineByID(50)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String()) c.Assert(
machine50.IPAddress,
check.Equals,
netaddr.MustParseIP("10.27.0.50").String(),
)
expectedNextIP := netaddr.MustParseIP("10.27.1.97") expectedNextIP := netaddr.MustParseIP("10.27.1.97")
nextIP, err := h.getAvailableIP() nextIP, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive // If we call get Available again, we should receive
// the same IP, as it has not been reserved. // the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP() nextIP2, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
} }
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := h.getAvailableIP() ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String()) c.Assert(ip.String(), check.Equals, expected.String())
n, err := h.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine") _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
h.db.Save(&m) app.db.Save(&machine)
ip2, err := h.getAvailableIP() ip2, err := app.getAvailableIP()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String()) c.Assert(ip2.String(), check.Equals, expected.String())