mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 17:03:06 +00:00
Merge branch 'main' into patch-1
This commit is contained in:
commit
bd7b5e97cb
66 changed files with 2981 additions and 1869 deletions
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
|
@ -18,12 +18,11 @@ jobs:
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.16.3"
|
go-version: "1.17.3"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
go version
|
go version
|
||||||
go install golang.org/x/lint/golint@latest
|
|
||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install -y make
|
sudo apt install -y make
|
||||||
|
|
||||||
|
|
39
.github/workflows/lint.yml
vendored
39
.github/workflows/lint.yml
vendored
|
@ -1,20 +1,37 @@
|
||||||
|
---
|
||||||
name: CI
|
name: CI
|
||||||
|
|
||||||
on: [push, pull_request]
|
on: [push, pull_request]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# The "build" workflow
|
golangci-lint:
|
||||||
lint:
|
|
||||||
# The type of runner that the job will run on
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
# Steps represent a sequence of tasks that will be executed as part of the job
|
|
||||||
steps:
|
steps:
|
||||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
# Install and run golangci-lint as a separate step, it's much faster this
|
- name: golangci-lint
|
||||||
# way because this action has caching. It'll get run again in `make lint`
|
uses: golangci/golangci-lint-action@v2
|
||||||
# below, but it's still much faster in the end than installing
|
with:
|
||||||
# golangci-lint manually in the `Run lint` step.
|
version: latest
|
||||||
- uses: golangci/golangci-lint-action@v2
|
|
||||||
|
prettier-lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Prettify code
|
||||||
|
uses: creyD/prettier_action@v4.0
|
||||||
|
with:
|
||||||
|
prettier_options: >-
|
||||||
|
--check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
|
||||||
|
only_changed: false
|
||||||
|
dry: true
|
||||||
|
|
||||||
|
proto-lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: bufbuild/buf-setup-action@v0.7.0
|
||||||
|
- uses: bufbuild/buf-lint-action@v1
|
||||||
|
with:
|
||||||
|
input: "proto"
|
||||||
|
|
|
@ -1,7 +1,53 @@
|
||||||
---
|
---
|
||||||
run:
|
run:
|
||||||
timeout: 5m
|
timeout: 10m
|
||||||
|
|
||||||
issues:
|
issues:
|
||||||
skip-dirs:
|
skip-dirs:
|
||||||
- gen
|
- gen
|
||||||
|
linters:
|
||||||
|
enable-all: true
|
||||||
|
disable:
|
||||||
|
- exhaustivestruct
|
||||||
|
- revive
|
||||||
|
- lll
|
||||||
|
- interfacer
|
||||||
|
- scopelint
|
||||||
|
- maligned
|
||||||
|
- golint
|
||||||
|
- gofmt
|
||||||
|
- gochecknoglobals
|
||||||
|
- gochecknoinits
|
||||||
|
- gocognit
|
||||||
|
- funlen
|
||||||
|
- exhaustivestruct
|
||||||
|
- tagliatelle
|
||||||
|
- godox
|
||||||
|
- ireturn
|
||||||
|
|
||||||
|
# In progress
|
||||||
|
- gocritic
|
||||||
|
|
||||||
|
# We should strive to enable these:
|
||||||
|
- wrapcheck
|
||||||
|
- dupl
|
||||||
|
- makezero
|
||||||
|
|
||||||
|
# We might want to enable this, but it might be a lot of work
|
||||||
|
- cyclop
|
||||||
|
- nestif
|
||||||
|
- wsl # might be incompatible with gofumpt
|
||||||
|
- testpackage
|
||||||
|
- paralleltest
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
varnamelen:
|
||||||
|
ignore-type-assert-ok: true
|
||||||
|
ignore-map-index-ok: true
|
||||||
|
ignore-names:
|
||||||
|
- err
|
||||||
|
- db
|
||||||
|
- id
|
||||||
|
- ip
|
||||||
|
- ok
|
||||||
|
- c
|
||||||
|
|
15
Makefile
15
Makefile
|
@ -1,6 +1,14 @@
|
||||||
# Calculate version
|
# Calculate version
|
||||||
version = $(shell ./scripts/version-at-commit.sh)
|
version = $(shell ./scripts/version-at-commit.sh)
|
||||||
|
|
||||||
|
rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
|
||||||
|
|
||||||
|
# GO_SOURCES = $(wildcard *.go)
|
||||||
|
# PROTO_SOURCES = $(wildcard **/*.proto)
|
||||||
|
GO_SOURCES = $(call rwildcard,,*.go)
|
||||||
|
PROTO_SOURCES = $(call rwildcard,,*.proto)
|
||||||
|
|
||||||
|
|
||||||
build:
|
build:
|
||||||
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
|
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
|
||||||
|
|
||||||
|
@ -19,7 +27,12 @@ coverprofile_html:
|
||||||
go tool cover -html=coverage.out
|
go tool cover -html=coverage.out
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
golangci-lint run --fix
|
golangci-lint run --fix --timeout 10m
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}'
|
||||||
|
golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES)
|
||||||
|
clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES)
|
||||||
|
|
||||||
proto-lint:
|
proto-lint:
|
||||||
cd proto/ && buf lint
|
cd proto/ && buf lint
|
||||||
|
|
21
README.md
21
README.md
|
@ -54,7 +54,6 @@ Suggestions/PRs welcomed!
|
||||||
|
|
||||||
Please have a look at the documentation under [`docs/`](docs/).
|
Please have a look at the documentation under [`docs/`](docs/).
|
||||||
|
|
||||||
|
|
||||||
## Disclaimer
|
## Disclaimer
|
||||||
|
|
||||||
1. We have nothing to do with Tailscale, or Tailscale Inc.
|
1. We have nothing to do with Tailscale, or Tailscale Inc.
|
||||||
|
@ -64,6 +63,23 @@ Please have a look at the documentation under [`docs/`](docs/).
|
||||||
|
|
||||||
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
|
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
|
||||||
|
|
||||||
|
### Code style
|
||||||
|
|
||||||
|
To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules:
|
||||||
|
|
||||||
|
The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and
|
||||||
|
formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and
|
||||||
|
[`gofumpt`](https://github.com/mvdan/gofumpt).
|
||||||
|
Please configure your editor to run the tools while developing and make sure to
|
||||||
|
run `make lint` and `make fmt` before committing any code.
|
||||||
|
|
||||||
|
The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and
|
||||||
|
formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html).
|
||||||
|
|
||||||
|
The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io).
|
||||||
|
|
||||||
|
Check out the `.golangci.yaml` and `Makefile` to see the specific configuration.
|
||||||
|
|
||||||
### Install development tools
|
### Install development tools
|
||||||
|
|
||||||
- Go
|
- Go
|
||||||
|
@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c
|
||||||
```shell
|
```shell
|
||||||
make generate
|
make generate
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
|
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
|
||||||
|
|
||||||
To run the tests:
|
To run the tests:
|
||||||
|
@ -261,5 +278,3 @@ make build
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
|
|
||||||
|
|
149
acls.go
149
acls.go
|
@ -9,23 +9,30 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/tailscale/hujson"
|
"github.com/tailscale/hujson"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errorEmptyPolicy = Error("empty policy")
|
errEmptyPolicy = Error("empty policy")
|
||||||
errorInvalidAction = Error("invalid action")
|
errInvalidAction = Error("invalid action")
|
||||||
errorInvalidUserSection = Error("invalid user section")
|
errInvalidUserSection = Error("invalid user section")
|
||||||
errorInvalidGroup = Error("invalid group")
|
errInvalidGroup = Error("invalid group")
|
||||||
errorInvalidTag = Error("invalid tag")
|
errInvalidTag = Error("invalid tag")
|
||||||
errorInvalidNamespace = Error("invalid namespace")
|
errInvalidNamespace = Error("invalid namespace")
|
||||||
errorInvalidPortFormat = Error("invalid port format")
|
errInvalidPortFormat = Error("invalid port format")
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
|
const (
|
||||||
|
Base10 = 10
|
||||||
|
BitSize16 = 16
|
||||||
|
portRangeBegin = 0
|
||||||
|
portRangeEnd = 65535
|
||||||
|
expectedTokenItems = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
|
||||||
func (h *Headscale) LoadACLPolicy(path string) error {
|
func (h *Headscale) LoadACLPolicy(path string) error {
|
||||||
policyFile, err := os.Open(path)
|
policyFile, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
||||||
defer policyFile.Close()
|
defer policyFile.Close()
|
||||||
|
|
||||||
var policy ACLPolicy
|
var policy ACLPolicy
|
||||||
b, err := io.ReadAll(policyFile)
|
policyBytes, err := io.ReadAll(policyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ast, err := hujson.Parse(b)
|
ast, err := hujson.Parse(policyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ast.Standardize()
|
ast.Standardize()
|
||||||
b = ast.Pack()
|
policyBytes = ast.Pack()
|
||||||
err = json.Unmarshal(b, &policy)
|
err = json.Unmarshal(policyBytes, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if policy.IsZero() {
|
if policy.IsZero() {
|
||||||
return errorEmptyPolicy
|
return errEmptyPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
h.aclPolicy = &policy
|
h.aclPolicy = &policy
|
||||||
|
@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.aclRules = rules
|
h.aclRules = rules
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
|
|
||||||
for i, a := range h.aclPolicy.ACLs {
|
for index, acl := range h.aclPolicy.ACLs {
|
||||||
if a.Action != "accept" {
|
if acl.Action != "accept" {
|
||||||
return nil, errorInvalidAction
|
return nil, errInvalidAction
|
||||||
}
|
}
|
||||||
|
|
||||||
r := tailcfg.FilterRule{}
|
filterRule := tailcfg.FilterRule{}
|
||||||
|
|
||||||
srcIPs := []string{}
|
srcIPs := []string{}
|
||||||
for j, u := range a.Users {
|
for innerIndex, user := range acl.Users {
|
||||||
srcs, err := h.generateACLPolicySrcIP(u)
|
srcs, err := h.generateACLPolicySrcIP(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, User %d", i, j)
|
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
srcIPs = append(srcIPs, srcs...)
|
srcIPs = append(srcIPs, srcs...)
|
||||||
}
|
}
|
||||||
r.SrcIPs = srcIPs
|
filterRule.SrcIPs = srcIPs
|
||||||
|
|
||||||
destPorts := []tailcfg.NetPortRange{}
|
destPorts := []tailcfg.NetPortRange{}
|
||||||
for j, d := range a.Ports {
|
for innerIndex, ports := range acl.Ports {
|
||||||
dests, err := h.generateACLPolicyDestPorts(d)
|
dests, err := h.generateACLPolicyDestPorts(ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, Port %d", i, j)
|
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
destPorts = append(destPorts, dests...)
|
destPorts = append(destPorts, dests...)
|
||||||
|
@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
|
||||||
return h.expandAlias(u)
|
return h.expandAlias(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
|
func (h *Headscale) generateACLPolicyDestPorts(
|
||||||
|
d string,
|
||||||
|
) ([]tailcfg.NetPortRange, error) {
|
||||||
tokens := strings.Split(d, ":")
|
tokens := strings.Split(d, ":")
|
||||||
if len(tokens) < 2 || len(tokens) > 3 {
|
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
|
||||||
return nil, errorInvalidPortFormat
|
return nil, errInvalidPortFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
var alias string
|
var alias string
|
||||||
|
@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
|
||||||
// tag:montreal-webserver:80,443
|
// tag:montreal-webserver:80,443
|
||||||
// tag:api-server:443
|
// tag:api-server:443
|
||||||
// example-host-1:*
|
// example-host-1:*
|
||||||
if len(tokens) == 2 {
|
if len(tokens) == expectedTokenItems {
|
||||||
alias = tokens[0]
|
alias = tokens[0]
|
||||||
} else {
|
} else {
|
||||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||||
|
@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
|
||||||
dests = append(dests, pr)
|
dests = append(dests, pr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return dests, nil
|
return dests, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expandAlias(s string) ([]string, error) {
|
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||||
if s == "*" {
|
if alias == "*" {
|
||||||
return []string{"*"}, nil
|
return []string{"*"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(s, "group:") {
|
if strings.HasPrefix(alias, "group:") {
|
||||||
if _, ok := h.aclPolicy.Groups[s]; !ok {
|
if _, ok := h.aclPolicy.Groups[alias]; !ok {
|
||||||
return nil, errorInvalidGroup
|
return nil, errInvalidGroup
|
||||||
}
|
}
|
||||||
ips := []string{}
|
ips := []string{}
|
||||||
for _, n := range h.aclPolicy.Groups[s] {
|
for _, n := range h.aclPolicy.Groups[alias] {
|
||||||
nodes, err := h.ListMachinesInNamespace(n)
|
nodes, err := h.ListMachinesInNamespace(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorInvalidNamespace
|
return nil, errInvalidNamespace
|
||||||
}
|
}
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
ips = append(ips, node.IPAddress)
|
ips = append(ips, node.IPAddress)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(s, "tag:") {
|
if strings.HasPrefix(alias, "tag:") {
|
||||||
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
|
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
|
||||||
return nil, errorInvalidTag
|
return nil, errInvalidTag
|
||||||
}
|
}
|
||||||
|
|
||||||
// This will have HORRIBLE performance.
|
// This will have HORRIBLE performance.
|
||||||
|
@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ips := []string{}
|
ips := []string{}
|
||||||
for _, m := range machines {
|
for _, machine := range machines {
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
|
|
||||||
// FIXME: Check TagOwners allows this
|
// FIXME: Check TagOwners allows this
|
||||||
for _, t := range hostinfo.RequestTags {
|
for _, t := range hostinfo.RequestTags {
|
||||||
if s[4:] == t {
|
if alias[4:] == t {
|
||||||
ips = append(ips, m.IPAddress)
|
ips = append(ips, machine.IPAddress)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := h.GetNamespace(s)
|
n, err := h.GetNamespace(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
nodes, err := h.ListMachinesInNamespace(n.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
ips = append(ips, n.IPAddress)
|
ips = append(ips, n.IPAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if h, ok := h.aclPolicy.Hosts[s]; ok {
|
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
||||||
return []string{h.String()}, nil
|
return []string{h.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := netaddr.ParseIP(s)
|
ip, err := netaddr.ParseIP(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{ip.String()}, nil
|
return []string{ip.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cidr, err := netaddr.ParseIPPrefix(s)
|
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{cidr.String()}, nil
|
return []string{cidr.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errorInvalidUserSection
|
return nil, errInvalidUserSection
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||||
if s == "*" {
|
if portsStr == "*" {
|
||||||
return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
|
return &[]tailcfg.PortRange{
|
||||||
|
{First: portRangeBegin, Last: portRangeEnd},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ports := []tailcfg.PortRange{}
|
ports := []tailcfg.PortRange{}
|
||||||
for _, p := range strings.Split(s, ",") {
|
for _, portStr := range strings.Split(portsStr, ",") {
|
||||||
rang := strings.Split(p, "-")
|
rang := strings.Split(portStr, "-")
|
||||||
if len(rang) == 1 {
|
switch len(rang) {
|
||||||
pi, err := strconv.ParseUint(rang[0], 10, 16)
|
case 1:
|
||||||
|
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ports = append(ports, tailcfg.PortRange{
|
ports = append(ports, tailcfg.PortRange{
|
||||||
First: uint16(pi),
|
First: uint16(port),
|
||||||
Last: uint16(pi),
|
Last: uint16(port),
|
||||||
})
|
})
|
||||||
} else if len(rang) == 2 {
|
|
||||||
start, err := strconv.ParseUint(rang[0], 10, 16)
|
case expectedTokenItems:
|
||||||
|
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
last, err := strconv.ParseUint(rang[1], 10, 16)
|
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
||||||
First: uint16(start),
|
First: uint16(start),
|
||||||
Last: uint16(last),
|
Last: uint16(last),
|
||||||
})
|
})
|
||||||
} else {
|
|
||||||
return nil, errorInvalidPortFormat
|
default:
|
||||||
|
return nil, errInvalidPortFormat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports, nil
|
return &ports, nil
|
||||||
}
|
}
|
||||||
|
|
76
acls_test.go
76
acls_test.go
|
@ -5,54 +5,58 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestWrongPath(c *check.C) {
|
func (s *Suite) TestWrongPath(c *check.C) {
|
||||||
err := h.LoadACLPolicy("asdfg")
|
err := app.LoadACLPolicy("asdfg")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestBrokenHuJson(c *check.C) {
|
func (s *Suite) TestBrokenHuJson(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/broken.hujson")
|
err := app.LoadACLPolicy("./tests/acls/broken.hujson")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
|
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/invalid.hujson")
|
err := app.LoadACLPolicy("./tests/acls/invalid.hujson")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(err, check.Equals, errorEmptyPolicy)
|
c.Assert(err, check.Equals, errEmptyPolicy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestParseHosts(c *check.C) {
|
func (s *Suite) TestParseHosts(c *check.C) {
|
||||||
var hs Hosts
|
var hosts Hosts
|
||||||
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
|
err := hosts.UnmarshalJSON(
|
||||||
c.Assert(hs, check.NotNil)
|
[]byte(
|
||||||
|
`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
c.Assert(hosts, check.NotNil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
|
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
|
||||||
var hs Hosts
|
var hosts Hosts
|
||||||
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
|
err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
|
||||||
c.Assert(hs, check.IsNil)
|
c.Assert(hosts, check.IsNil)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
|
err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestBasicRule(c *check.C) {
|
func (s *Suite) TestBasicRule(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
|
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := h.generateACLRules()
|
rules, err := app.generateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortRange(c *check.C) {
|
func (s *Suite) TestPortRange(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
|
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := h.generateACLRules()
|
rules, err := app.generateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortWildcard(c *check.C) {
|
func (s *Suite) TestPortWildcard(c *check.C) {
|
||||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
|
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := h.generateACLRules()
|
rules, err := app.generateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortNamespace(c *check.C) {
|
func (s *Suite) TestPortNamespace(c *check.C) {
|
||||||
n, err := h.CreateNamespace("testnamespace")
|
namespace, err := app.CreateNamespace("testnamespace")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("testnamespace", "testmachine")
|
_, err = app.GetMachine("testnamespace", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
ip, _ := h.getAvailableIP()
|
ip, _ := app.getAvailableIP()
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: ip.String(),
|
IPAddress: ip.String(),
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson")
|
err = app.LoadACLPolicy(
|
||||||
|
"./tests/acls/acl_policy_basic_namespace_as_user.hujson",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := h.generateACLRules()
|
rules, err := app.generateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortGroup(c *check.C) {
|
func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
n, err := h.CreateNamespace("testnamespace")
|
namespace, err := app.CreateNamespace("testnamespace")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("testnamespace", "testmachine")
|
_, err = app.GetMachine("testnamespace", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
ip, _ := h.getAvailableIP()
|
ip, _ := app.getAvailableIP()
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: ip.String(),
|
IPAddress: ip.String(),
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
|
err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := h.generateACLRules()
|
rules, err := app.generateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ACLPolicy represents a Tailscale ACL Policy
|
// ACLPolicy represents a Tailscale ACL Policy.
|
||||||
type ACLPolicy struct {
|
type ACLPolicy struct {
|
||||||
Groups Groups `json:"Groups"`
|
Groups Groups `json:"Groups"`
|
||||||
Hosts Hosts `json:"Hosts"`
|
Hosts Hosts `json:"Hosts"`
|
||||||
|
@ -17,61 +17,63 @@ type ACLPolicy struct {
|
||||||
Tests []ACLTest `json:"Tests"`
|
Tests []ACLTest `json:"Tests"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ACL is a basic rule for the ACL Policy
|
// ACL is a basic rule for the ACL Policy.
|
||||||
type ACL struct {
|
type ACL struct {
|
||||||
Action string `json:"Action"`
|
Action string `json:"Action"`
|
||||||
Users []string `json:"Users"`
|
Users []string `json:"Users"`
|
||||||
Ports []string `json:"Ports"`
|
Ports []string `json:"Ports"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Groups references a series of alias in the ACL rules
|
// Groups references a series of alias in the ACL rules.
|
||||||
type Groups map[string][]string
|
type Groups map[string][]string
|
||||||
|
|
||||||
// Hosts are alias for IP addresses or subnets
|
// Hosts are alias for IP addresses or subnets.
|
||||||
type Hosts map[string]netaddr.IPPrefix
|
type Hosts map[string]netaddr.IPPrefix
|
||||||
|
|
||||||
// TagOwners specify what users (namespaces?) are allow to use certain tags
|
// TagOwners specify what users (namespaces?) are allow to use certain tags.
|
||||||
type TagOwners map[string][]string
|
type TagOwners map[string][]string
|
||||||
|
|
||||||
// ACLTest is not implemented, but should be use to check if a certain rule is allowed
|
// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
|
||||||
type ACLTest struct {
|
type ACLTest struct {
|
||||||
User string `json:"User"`
|
User string `json:"User"`
|
||||||
Allow []string `json:"Allow"`
|
Allow []string `json:"Allow"`
|
||||||
Deny []string `json:"Deny,omitempty"`
|
Deny []string `json:"Deny,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects
|
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
|
||||||
func (h *Hosts) UnmarshalJSON(data []byte) error {
|
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
|
||||||
hosts := Hosts{}
|
newHosts := Hosts{}
|
||||||
hs := make(map[string]string)
|
hostIPPrefixMap := make(map[string]string)
|
||||||
ast, err := hujson.Parse(data)
|
ast, err := hujson.Parse(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ast.Standardize()
|
ast.Standardize()
|
||||||
data = ast.Pack()
|
data = ast.Pack()
|
||||||
err = json.Unmarshal(data, &hs)
|
err = json.Unmarshal(data, &hostIPPrefixMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for k, v := range hs {
|
for host, prefixStr := range hostIPPrefixMap {
|
||||||
if !strings.Contains(v, "/") {
|
if !strings.Contains(prefixStr, "/") {
|
||||||
v = v + "/32"
|
prefixStr += "/32"
|
||||||
}
|
}
|
||||||
prefix, err := netaddr.ParseIPPrefix(v)
|
prefix, err := netaddr.ParseIPPrefix(prefixStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hosts[k] = prefix
|
newHosts[host] = prefix
|
||||||
}
|
}
|
||||||
*h = hosts
|
*hosts = newHosts
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZero is perhaps a bit naive here
|
// IsZero is perhaps a bit naive here.
|
||||||
func (p ACLPolicy) IsZero() bool {
|
func (policy ACLPolicy) IsZero() bool {
|
||||||
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
|
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
297
api.go
297
api.go
|
@ -10,31 +10,37 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const reservedResponseHeaderSize = 4
|
||||||
|
|
||||||
// KeyHandler provides the Headscale pub key
|
// KeyHandler provides the Headscale pub key
|
||||||
// Listens in /key
|
// Listens in /key.
|
||||||
func (h *Headscale) KeyHandler(c *gin.Context) {
|
func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
||||||
c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
|
ctx.Data(
|
||||||
|
http.StatusOK,
|
||||||
|
"text/plain; charset=utf-8",
|
||||||
|
[]byte(h.publicKey.HexString()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register
|
// Listens in /register.
|
||||||
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||||
mKeyStr := c.Query("key")
|
machineKeyStr := ctx.Query("key")
|
||||||
if mKeyStr == "" {
|
if machineKeyStr == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>headscale</h1>
|
<h1>headscale</h1>
|
||||||
|
@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
`, mKeyStr)))
|
`, machineKeyStr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegistrationHandler handles the actual registration process of a machine
|
// RegistrationHandler handles the actual registration process of a machine
|
||||||
// Endpoint /machine/:id
|
// Endpoint /machine/:id.
|
||||||
func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
body, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
mKeyStr := c.Param("id")
|
machineKeyStr := ctx.Param("id")
|
||||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot parse machine key")
|
Msg("Cannot parse machine key")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
c.String(http.StatusInternalServerError, "Sad!")
|
ctx.String(http.StatusInternalServerError, "Sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req := tailcfg.RegisterRequest{}
|
req := tailcfg.RegisterRequest{}
|
||||||
err = decode(body, &req, &mKey, h.privateKey)
|
err = decode(body, &req, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot decode message")
|
Msg("Cannot decode message")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
c.String(http.StatusInternalServerError, "Very sad!")
|
ctx.String(http.StatusInternalServerError, "Very sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||||
newMachine := Machine{
|
newMachine := Machine{
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
MachineKey: mKey.HexString(),
|
MachineKey: machineKey.HexString(),
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
}
|
}
|
||||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||||
|
@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not create row")
|
Msg("Could not create row")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||||
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m = &newMachine
|
machine = &newMachine
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.Registered && req.Auth.AuthKey != "" {
|
if !machine.Registered && req.Auth.AuthKey != "" {
|
||||||
h.handleAuthKey(c, h.db, mKey, req, *m)
|
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
// We have the updated key!
|
// We have the updated key!
|
||||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||||
|
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client requested logout")
|
Msg("Client requested logout")
|
||||||
|
|
||||||
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Registered && m.Expiry.UTC().After(now) {
|
if machine.Registered && machine.Expiry.UTC().After(now) {
|
||||||
// The machine registration is valid, respond with redirect to /map
|
// The machine registration is valid, respond with redirect to /map
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
resp.Login = *m.Namespace.toLogin()
|
resp.Login = *machine.Namespace.toLogin()
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||||
c.String(http.StatusInternalServerError, "")
|
Inc()
|
||||||
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
Inc()
|
||||||
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// The client has registered before, but has expired
|
// The client has registered before, but has expired
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Machine registration has expired. Sending a authurl to register")
|
Msg("Machine registration has expired. Sending a authurl to register")
|
||||||
|
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a client connects, it may request a specific expiry time in its
|
// When a client connects, it may request a specific expiry time in its
|
||||||
|
@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
||||||
// retrieved again after the user has authenticated. After the authentication flow
|
// retrieved again after the user has authenticated. After the authentication flow
|
||||||
// completes, RequestedExpiry is copied into Expiry.
|
// completes, RequestedExpiry is copied into Expiry.
|
||||||
m.RequestedExpiry = &req.Expiry
|
machine.RequestedExpiry = &req.Expiry
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
|
||||||
c.String(http.StatusInternalServerError, "")
|
Inc()
|
||||||
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
Inc()
|
||||||
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
||||||
|
machine.Expiry.UTC().After(now) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// The machine registration is new, redirect the client to the registration URL
|
// The machine registration is new, redirect the client to the registration URL
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
resp.AuthURL = fmt.Sprintf(
|
||||||
|
"%s/oidc/register/%s",
|
||||||
|
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||||
|
machineKey.HexString(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the requested expiry time for retrieval later in the authentication flow
|
// save the requested expiry time for retrieval later in the authentication flow
|
||||||
m.RequestedExpiry = &req.Expiry
|
machine.RequestedExpiry = &req.Expiry
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
|
func (h *Headscale) getMapResponse(
|
||||||
|
machineKey wgkey.Key,
|
||||||
|
req tailcfg.MapRequest,
|
||||||
|
machine *Machine,
|
||||||
|
) ([]byte, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Str("machine", req.Hostinfo.Hostname).
|
Str("machine", req.Hostinfo.Hostname).
|
||||||
Msg("Creating Map response")
|
Msg("Creating Map response")
|
||||||
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot convert to node")
|
Msg("Cannot convert to node")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.getPeers(m)
|
peers, err := h.getPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot fetch peers")
|
Msg("Cannot fetch peers")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := getMapResponseUserProfiles(*m, peers)
|
profiles := getMapResponseUserProfiles(*machine, peers)
|
||||||
|
|
||||||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to convert peers to Tailscale nodes")
|
Msg("Failed to convert peers to Tailscale nodes")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers)
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
if err != nil {
|
h.cfg.DNSConfig,
|
||||||
log.Error().
|
h.cfg.BaseDomain,
|
||||||
Str("func", "getMapResponse").
|
*machine,
|
||||||
Err(err).
|
peers,
|
||||||
Msg("Failed generate the DNSConfig")
|
)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
resp := tailcfg.MapResponse{
|
||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
|
@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
|
||||||
|
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// declare the incoming size on the first 4 bytes
|
// declare the incoming size on the first 4 bytes
|
||||||
data := make([]byte, 4)
|
data := make([]byte, reservedResponseHeaderSize)
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||||
data = append(data, respBody...)
|
data = append(data, respBody...)
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
|
func (h *Headscale) getMapKeepAliveResponse(
|
||||||
resp := tailcfg.MapResponse{
|
machineKey wgkey.Key,
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
) ([]byte, error) {
|
||||||
|
mapResponse := tailcfg.MapResponse{
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
}
|
}
|
||||||
var respBody []byte
|
var respBody []byte
|
||||||
var err error
|
var err error
|
||||||
if req.Compress == "zstd" {
|
if mapRequest.Compress == "zstd" {
|
||||||
src, _ := json.Marshal(resp)
|
src, _ := json.Marshal(mapResponse)
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data := make([]byte, 4)
|
data := make([]byte, reservedResponseHeaderSize)
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||||
data = append(data, respBody...)
|
data = append(data, respBody...)
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleAuthKey(
|
func (h *Headscale) handleAuthKey(
|
||||||
c *gin.Context,
|
ctx *gin.Context,
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
idKey wgkey.Key,
|
idKey wgkey.Key,
|
||||||
req tailcfg.RegisterRequest,
|
reqisterRequest tailcfg.RegisterRequest,
|
||||||
m Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", req.Hostinfo.Hostname).
|
Str("machine", reqisterRequest.Hostinfo.Hostname).
|
||||||
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
|
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
|
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
|
@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(401, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed to find an available IP")
|
Msg("Failed to find an available IP")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msgf("Assigning %s to %s", ip, m.Name)
|
Msgf("Assigning %s to %s", ip, machine.Name)
|
||||||
|
|
||||||
m.AuthKeyID = uint(pak.ID)
|
machine.AuthKeyID = uint(pak.ID)
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = pak.NamespaceID
|
machine.NamespaceID = pak.NamespaceID
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
|
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
|
||||||
m.Registered = true
|
HexString()
|
||||||
m.RegisterMethod = "authKey"
|
// we update it just in case
|
||||||
db.Save(&m)
|
machine.Registered = true
|
||||||
|
machine.RegisterMethod = "authKey"
|
||||||
|
db.Save(&machine)
|
||||||
|
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
db.Save(&pak)
|
db.Save(&pak)
|
||||||
|
@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
Inc()
|
||||||
|
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc()
|
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
|
||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
Inc()
|
||||||
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Successfully authenticated via AuthKey")
|
Msg("Successfully authenticated via AuthKey")
|
||||||
}
|
}
|
||||||
|
|
271
app.go
271
app.go
|
@ -18,20 +18,19 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware"
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/philip-bui/grpc-zerolog"
|
"github.com/patrickmn/go-cache"
|
||||||
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
zl "github.com/rs/zerolog"
|
zl "github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/soheilhy/cmux"
|
"github.com/soheilhy/cmux"
|
||||||
ginprometheus "github.com/zsais/go-gin-prometheus"
|
ginprometheus "github.com/zsais/go-gin-prometheus"
|
||||||
"golang.org/x/crypto/acme"
|
"golang.org/x/crypto/acme"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
@ -48,7 +47,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AUTH_PREFIX = "Bearer "
|
AuthPrefix = "Bearer "
|
||||||
|
Postgres = "postgresql"
|
||||||
|
Sqlite = "sqlite3"
|
||||||
|
updateInterval = 5000
|
||||||
|
HTTPReadTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
errUnsupportedDatabase = Error("unsupported DB")
|
||||||
|
errUnsupportedLetsEncryptChallengeType = Error(
|
||||||
|
"unknown value for Lets Encrypt challenge type",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config contains the initial Headscale configuration.
|
// Config contains the initial Headscale configuration.
|
||||||
|
@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
|
|
||||||
var dbString string
|
var dbString string
|
||||||
switch cfg.DBtype {
|
switch cfg.DBtype {
|
||||||
case "postgres":
|
case Postgres:
|
||||||
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
|
dbString = fmt.Sprintf(
|
||||||
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
|
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
|
||||||
case "sqlite3":
|
cfg.DBhost,
|
||||||
|
cfg.DBport,
|
||||||
|
cfg.DBname,
|
||||||
|
cfg.DBuser,
|
||||||
|
cfg.DBpass,
|
||||||
|
)
|
||||||
|
case Sqlite:
|
||||||
dbString = cfg.DBpath
|
dbString = cfg.DBpath
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported DB")
|
return nil, errUnsupportedDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
h := Headscale{
|
app := Headscale{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: cfg.DBtype,
|
dbType: cfg.DBtype,
|
||||||
dbString: dbString,
|
dbString: dbString,
|
||||||
|
@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.initDB()
|
err = app.initDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = h.initOIDC()
|
err = app.initOIDC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||||
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
|
magicDNSDomains := generateMagicDNSRootDomains(
|
||||||
if err != nil {
|
app.cfg.IPPrefix,
|
||||||
return nil, err
|
)
|
||||||
}
|
|
||||||
// we might have routes already from Split DNS
|
// we might have routes already from Split DNS
|
||||||
if h.cfg.DNSConfig.Routes == nil {
|
if app.cfg.DNSConfig.Routes == nil {
|
||||||
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||||
}
|
}
|
||||||
for _, d := range magicDNSDomains {
|
for _, d := range magicDNSDomains {
|
||||||
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &h, nil
|
return &app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect to our TLS url.
|
// Redirect to our TLS url.
|
||||||
|
@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range namespaces {
|
for _, namespace := range namespaces {
|
||||||
machines, err := h.ListMachinesInNamespace(ns.Name)
|
machines, err := h.ListMachinesInNamespace(namespace.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Str("namespace", namespace.Name).
|
||||||
|
Msg("Error listing machines in namespace")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range machines {
|
for _, machine := range machines {
|
||||||
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
|
if machine.AuthKey != nil && machine.LastSeen != nil &&
|
||||||
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
machine.AuthKey.Ephemeral &&
|
||||||
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
|
time.Now().
|
||||||
|
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||||
|
log.Info().
|
||||||
|
Str("machine", machine.Name).
|
||||||
|
Msg("Ephemeral client removed from database")
|
||||||
|
|
||||||
err = h.db.Unscoped().Delete(m).Error
|
err = h.db.Unscoped().Delete(machine).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.setLastStateChangeToNow(ns.Name)
|
h.setLastStateChangeToNow(namespace.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
req interface{},
|
req interface{},
|
||||||
info *grpc.UnaryServerInfo,
|
info *grpc.UnaryServerInfo,
|
||||||
handler grpc.UnaryHandler) (interface{}, error) {
|
handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
|
|
||||||
// Check if the request is coming from the on-server client.
|
// Check if the request is coming from the on-server client.
|
||||||
// This is not secure, but it is to maintain maintainability
|
// This is not secure, but it is to maintain maintainability
|
||||||
// with the "legacy" database-based client
|
// with the "legacy" database-based client
|
||||||
// It is also neede for grpc-gateway to be able to connect to
|
// It is also neede for grpc-gateway to be able to connect to
|
||||||
// the server
|
// the server
|
||||||
p, _ := peer.FromContext(ctx)
|
client, _ := peer.FromContext(ctx)
|
||||||
|
|
||||||
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate")
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("client_address", client.Addr.String()).
|
||||||
|
Msg("Client is trying to authenticate")
|
||||||
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
meta, ok := metadata.FromIncomingContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed")
|
log.Error().
|
||||||
return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
|
Caller().
|
||||||
|
Str("client_address", client.Addr.String()).
|
||||||
|
Msg("Retrieving metadata is failed")
|
||||||
|
|
||||||
|
return ctx, status.Errorf(
|
||||||
|
codes.InvalidArgument,
|
||||||
|
"Retrieving metadata is failed",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
authHeader, ok := md["authorization"]
|
authHeader, ok := meta["authorization"]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied")
|
log.Error().
|
||||||
return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied")
|
Caller().
|
||||||
|
Str("client_address", client.Addr.String()).
|
||||||
|
Msg("Authorization token is not supplied")
|
||||||
|
|
||||||
|
return ctx, status.Errorf(
|
||||||
|
codes.Unauthenticated,
|
||||||
|
"Authorization token is not supplied",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
token := authHeader[0]
|
token := authHeader[0]
|
||||||
|
|
||||||
if !strings.HasPrefix(token, AUTH_PREFIX) {
|
if !strings.HasPrefix(token, AuthPrefix) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", p.Addr.String()).
|
Str("client_address", client.Addr.String()).
|
||||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||||
return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
|
|
||||||
|
return ctx, status.Error(
|
||||||
|
codes.Unauthenticated,
|
||||||
|
`missing "Bearer " prefix in "Authorization" header`,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Implement API key backend:
|
// TODO(kradalby): Implement API key backend:
|
||||||
|
@ -307,7 +347,10 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
|
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
|
||||||
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
|
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
|
||||||
// and API key auth
|
// and API key auth
|
||||||
return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet")
|
return ctx, status.Error(
|
||||||
|
codes.Unauthenticated,
|
||||||
|
"Authentication is not implemented yet",
|
||||||
|
)
|
||||||
|
|
||||||
// if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
|
// if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
|
||||||
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
|
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
|
||||||
|
@ -317,25 +360,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
// return handler(ctx, req)
|
// return handler(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
|
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", c.ClientIP()).
|
Str("client_address", ctx.ClientIP()).
|
||||||
Msg("HTTP authentication invoked")
|
Msg("HTTP authentication invoked")
|
||||||
|
|
||||||
authHeader := c.GetHeader("authorization")
|
authHeader := ctx.GetHeader("authorization")
|
||||||
|
|
||||||
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
|
if !strings.HasPrefix(authHeader, AuthPrefix) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", c.ClientIP()).
|
Str("client_address", ctx.ClientIP()).
|
||||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||||
|
|
||||||
// TODO(kradalby): Implement API key backend
|
// TODO(kradalby): Implement API key backend
|
||||||
// Currently all traffic is unauthorized, this is intentional to allow
|
// Currently all traffic is unauthorized, this is intentional to allow
|
||||||
|
@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.Remove(h.cfg.UnixSocket)
|
return os.Remove(h.cfg.UnixSocket)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,14 +445,17 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
// Create the cmux object that will multiplex 2 protocols on the same port.
|
// Create the cmux object that will multiplex 2 protocols on the same port.
|
||||||
// The two following listeners will be served on the same port below gracefully.
|
// The two following listeners will be served on the same port below gracefully.
|
||||||
m := cmux.New(networkListener)
|
networkMutex := cmux.New(networkListener)
|
||||||
// Match gRPC requests here
|
// Match gRPC requests here
|
||||||
grpcListener := m.MatchWithWriters(
|
grpcListener := networkMutex.MatchWithWriters(
|
||||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
||||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"),
|
cmux.HTTP2MatchHeaderFieldSendSettings(
|
||||||
|
"content-type",
|
||||||
|
"application/grpc+proto",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
// Otherwise match regular http requests.
|
// Otherwise match regular http requests.
|
||||||
httpListener := m.Match(cmux.Any())
|
httpListener := networkMutex.Match(cmux.Any())
|
||||||
|
|
||||||
grpcGatewayMux := runtime.NewServeMux()
|
grpcGatewayMux := runtime.NewServeMux()
|
||||||
|
|
||||||
|
@ -431,30 +478,33 @@ func (h *Headscale) Serve() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
p := ginprometheus.NewPrometheus("gin")
|
prometheus := ginprometheus.NewPrometheus("gin")
|
||||||
p.Use(r)
|
prometheus.Use(router)
|
||||||
|
|
||||||
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) })
|
router.GET(
|
||||||
r.GET("/key", h.KeyHandler)
|
"/health",
|
||||||
r.GET("/register", h.RegisterWebAPI)
|
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
|
||||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
)
|
||||||
r.POST("/machine/:id", h.RegistrationHandler)
|
router.GET("/key", h.KeyHandler)
|
||||||
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
router.GET("/register", h.RegisterWebAPI)
|
||||||
r.GET("/oidc/callback", h.OIDCCallback)
|
router.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||||
r.GET("/apple", h.AppleMobileConfig)
|
router.POST("/machine/:id", h.RegistrationHandler)
|
||||||
r.GET("/apple/:platform", h.ApplePlatformConfig)
|
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||||
r.GET("/swagger", SwaggerUI)
|
router.GET("/oidc/callback", h.OIDCCallback)
|
||||||
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
router.GET("/apple", h.AppleMobileConfig)
|
||||||
|
router.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||||
|
router.GET("/swagger", SwaggerUI)
|
||||||
|
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||||
|
|
||||||
api := r.Group("/api")
|
api := router.Group("/api")
|
||||||
api.Use(h.httpAuthenticationMiddleware)
|
api.Use(h.httpAuthenticationMiddleware)
|
||||||
{
|
{
|
||||||
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.NoRoute(stdoutHandler)
|
router.NoRoute(stdoutHandler)
|
||||||
|
|
||||||
// Fetch an initial DERP Map before we start serving
|
// Fetch an initial DERP Map before we start serving
|
||||||
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
||||||
|
@ -466,14 +516,13 @@ func (h *Headscale) Serve() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// I HATE THIS
|
// I HATE THIS
|
||||||
updateMillisecondsWait := int64(5000)
|
go h.watchForKVUpdates(updateInterval)
|
||||||
go h.watchForKVUpdates(updateMillisecondsWait)
|
go h.expireEphemeralNodes(updateInterval)
|
||||||
go h.expireEphemeralNodes(updateMillisecondsWait)
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: h.cfg.Addr,
|
Addr: h.cfg.Addr,
|
||||||
Handler: r,
|
Handler: router,
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: HTTPReadTimeout,
|
||||||
// Go does not handle timeouts in HTTP very well, and there is
|
// Go does not handle timeouts in HTTP very well, and there is
|
||||||
// no good way to handle streaming timeouts, therefore we need to
|
// no good way to handle streaming timeouts, therefore we need to
|
||||||
// keep this at unlimited and be careful to clean up connections
|
// keep this at unlimited and be careful to clean up connections
|
||||||
|
@ -519,36 +568,40 @@ func (h *Headscale) Serve() error {
|
||||||
reflection.Register(grpcServer)
|
reflection.Register(grpcServer)
|
||||||
reflection.Register(grpcSocket)
|
reflection.Register(grpcSocket)
|
||||||
|
|
||||||
g := new(errgroup.Group)
|
errorGroup := new(errgroup.Group)
|
||||||
|
|
||||||
g.Go(func() error { return grpcSocket.Serve(socketListener) })
|
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||||
|
|
||||||
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
||||||
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
g.Go(func() error {
|
errorGroup.Go(func() error {
|
||||||
tlsl := tls.NewListener(httpListener, tlsConfig)
|
tlsl := tls.NewListener(httpListener, tlsConfig)
|
||||||
|
|
||||||
return httpServer.Serve(tlsl)
|
return httpServer.Serve(tlsl)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
g.Go(func() error { return httpServer.Serve(httpListener) })
|
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Go(func() error { return m.Serve() })
|
errorGroup.Go(func() error { return networkMutex.Serve() })
|
||||||
|
|
||||||
log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
log.Info().
|
||||||
|
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
||||||
|
|
||||||
return g.Wait()
|
return errorGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
|
var err error
|
||||||
if h.cfg.TLSLetsEncryptHostname != "" {
|
if h.cfg.TLSLetsEncryptHostname != "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
log.Warn().
|
||||||
|
Msg("Listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := autocert.Manager{
|
certManager := autocert.Manager{
|
||||||
Prompt: autocert.AcceptTOS,
|
Prompt: autocert.AcceptTOS,
|
||||||
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
||||||
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
||||||
|
@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
Email: h.cfg.ACMEEmail,
|
Email: h.cfg.ACMEEmail,
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
|
switch h.cfg.TLSLetsEncryptChallengeType {
|
||||||
|
case "TLS-ALPN-01":
|
||||||
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
||||||
// The RFC requires that the validation is done on port 443; in other words, headscale
|
// The RFC requires that the validation is done on port 443; in other words, headscale
|
||||||
// must be reachable on port 443.
|
// must be reachable on port 443.
|
||||||
return m.TLSConfig(), nil
|
return certManager.TLSConfig(), nil
|
||||||
} else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" {
|
|
||||||
|
case "HTTP-01":
|
||||||
// Configuration via autocert with HTTP-01. This requires listening on
|
// Configuration via autocert with HTTP-01. This requires listening on
|
||||||
// port 80 for the certificate validation in addition to the headscale
|
// port 80 for the certificate validation in addition to the headscale
|
||||||
// service, which can be configured to run on any other port.
|
// service, which can be configured to run on any other port.
|
||||||
go func() {
|
go func() {
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||||
Msg("failed to set up a HTTP server")
|
Msg("failed to set up a HTTP server")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return m.TLSConfig(), nil
|
return certManager.TLSConfig(), nil
|
||||||
} else {
|
|
||||||
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
|
default:
|
||||||
|
return nil, errUnsupportedLetsEncryptChallengeType
|
||||||
}
|
}
|
||||||
} else if h.cfg.TLSCertPath == "" {
|
} else if h.cfg.TLSCertPath == "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
var err error
|
tlsConfig := &tls.Config{
|
||||||
tlsConfig := &tls.Config{}
|
ClientAuth: tls.RequireAnyClientCert,
|
||||||
tlsConfig.ClientAuth = tls.RequireAnyClientCert
|
NextProtos: []string{"http/1.1"},
|
||||||
tlsConfig.NextProtos = []string{"http/1.1"}
|
Certificates: make([]tls.Certificate, 1),
|
||||||
tlsConfig.Certificates = make([]tls.Certificate, 1)
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
||||||
|
|
||||||
return tlsConfig, err
|
return tlsConfig, err
|
||||||
|
@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stdoutHandler(c *gin.Context) {
|
func stdoutHandler(ctx *gin.Context) {
|
||||||
b, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Interface("header", c.Request.Header).
|
Interface("header", ctx.Request.Header).
|
||||||
Interface("proto", c.Request.Proto).
|
Interface("proto", ctx.Request.Proto).
|
||||||
Interface("url", c.Request.URL).
|
Interface("url", ctx.Request.URL).
|
||||||
Bytes("body", b).
|
Bytes("body", body).
|
||||||
Msg("Request did not match")
|
Msg("Request did not match")
|
||||||
}
|
}
|
||||||
|
|
14
app_test.go
14
app_test.go
|
@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
|
||||||
|
|
||||||
type Suite struct{}
|
type Suite struct{}
|
||||||
|
|
||||||
var tmpDir string
|
var (
|
||||||
var h Headscale
|
tmpDir string
|
||||||
|
app Headscale
|
||||||
|
)
|
||||||
|
|
||||||
func (s *Suite) SetUpTest(c *check.C) {
|
func (s *Suite) SetUpTest(c *check.C) {
|
||||||
s.ResetDB(c)
|
s.ResetDB(c)
|
||||||
|
@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
|
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
|
||||||
}
|
}
|
||||||
|
|
||||||
h = Headscale{
|
app = Headscale{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: "sqlite3",
|
dbType: "sqlite3",
|
||||||
dbString: tmpDir + "/headscale_test.db",
|
dbString: tmpDir + "/headscale_test.db",
|
||||||
}
|
}
|
||||||
err = h.initDB()
|
err = app.initDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
db, err := h.openDB()
|
db, err := app.openDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
h.db = db
|
app.db = db
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,16 +5,15 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register
|
// Listens in /register.
|
||||||
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
|
||||||
t := template.Must(template.New("apple").Parse(`
|
appleTemplate := template.Must(template.New("apple").Parse(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>Apple configuration profiles</h1>
|
<h1>Apple configuration profiles</h1>
|
||||||
|
@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||||
|
|
||||||
<p>Or</p>
|
<p>Or</p>
|
||||||
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
|
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
|
||||||
<code>defaults write io.tailscale.ipn.macos ControlURL {{.Url}}</code>
|
<code>defaults write io.tailscale.ipn.macos ControlURL {{.URL}}</code>
|
||||||
|
|
||||||
<p>Restart Tailscale.app and log in.</p>
|
<p>Restart Tailscale.app and log in.</p>
|
||||||
|
|
||||||
|
@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||||
</html>`))
|
</html>`))
|
||||||
|
|
||||||
config := map[string]interface{}{
|
config := map[string]interface{}{
|
||||||
"Url": h.cfg.ServerURL,
|
"URL": h.cfg.ServerURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload bytes.Buffer
|
var payload bytes.Buffer
|
||||||
if err := t.Execute(&payload, config); err != nil {
|
if err := appleTemplate.Execute(&payload, config); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "AppleMobileConfig").
|
Str("handler", "AppleMobileConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple index template")
|
Msg("Could not render Apple index template")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Could not render Apple index template"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
|
||||||
platform := c.Param("platform")
|
platform := ctx.Param("platform")
|
||||||
|
|
||||||
id, err := uuid.NewV4()
|
id, err := uuid.NewV4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed not create UUID")
|
Msg("Failed not create UUID")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Failed to create UUID"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
contentId, err := uuid.NewV4()
|
contentID, err := uuid.NewV4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed not create UUID")
|
Msg("Failed not create UUID")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Failed to create UUID"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
platformConfig := AppleMobilePlatformConfig{
|
platformConfig := AppleMobilePlatformConfig{
|
||||||
UUID: contentId,
|
UUID: contentID,
|
||||||
Url: h.cfg.ServerURL,
|
URL: h.cfg.ServerURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload bytes.Buffer
|
var payload bytes.Buffer
|
||||||
|
@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple macOS template")
|
Msg("Could not render Apple macOS template")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Could not render Apple macOS template"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "ios":
|
case "ios":
|
||||||
|
@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple iOS template")
|
Msg("Could not render Apple iOS template")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Could not render Apple iOS template"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"))
|
ctx.Data(
|
||||||
|
http.StatusOK,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Invalid platform, only ios and macos is supported"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config := AppleMobileConfig{
|
config := AppleMobileConfig{
|
||||||
UUID: id,
|
UUID: id,
|
||||||
Url: h.cfg.ServerURL,
|
URL: h.cfg.ServerURL,
|
||||||
Payload: payload.String(),
|
Payload: payload.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple platform template")
|
Msg("Could not render Apple platform template")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Could not render Apple platform template"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes())
|
ctx.Data(
|
||||||
|
http.StatusOK,
|
||||||
|
"application/x-apple-aspen-config; charset=utf-8",
|
||||||
|
content.Bytes(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppleMobileConfig struct {
|
type AppleMobileConfig struct {
|
||||||
UUID uuid.UUID
|
UUID uuid.UUID
|
||||||
Url string
|
URL string
|
||||||
Payload string
|
Payload string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppleMobilePlatformConfig struct {
|
type AppleMobilePlatformConfig struct {
|
||||||
UUID uuid.UUID
|
UUID uuid.UUID
|
||||||
Url string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
|
var commonTemplate = template.Must(
|
||||||
|
template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
<plist version="1.0">
|
<plist version="1.0">
|
||||||
<dict>
|
<dict>
|
||||||
|
@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
|
||||||
<key>PayloadDisplayName</key>
|
<key>PayloadDisplayName</key>
|
||||||
<string>Headscale</string>
|
<string>Headscale</string>
|
||||||
<key>PayloadDescription</key>
|
<key>PayloadDescription</key>
|
||||||
<string>Configure Tailscale login server to: {{.Url}}</string>
|
<string>Configure Tailscale login server to: {{.URL}}</string>
|
||||||
<key>PayloadIdentifier</key>
|
<key>PayloadIdentifier</key>
|
||||||
<string>com.github.juanfont.headscale</string>
|
<string>com.github.juanfont.headscale</string>
|
||||||
<key>PayloadRemovalDisallowed</key>
|
<key>PayloadRemovalDisallowed</key>
|
||||||
|
@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
|
||||||
{{.Payload}}
|
{{.Payload}}
|
||||||
</array>
|
</array>
|
||||||
</dict>
|
</dict>
|
||||||
</plist>`))
|
</plist>`),
|
||||||
|
)
|
||||||
|
|
||||||
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
|
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
|
||||||
<dict>
|
<dict>
|
||||||
|
@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
|
||||||
<true/>
|
<true/>
|
||||||
|
|
||||||
<key>ControlURL</key>
|
<key>ControlURL</key>
|
||||||
<string>{{.Url}}</string>
|
<string>{{.URL}}</string>
|
||||||
</dict>
|
</dict>
|
||||||
`))
|
`))
|
||||||
|
|
||||||
|
@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
|
||||||
<true/>
|
<true/>
|
||||||
|
|
||||||
<key>ControlURL</key>
|
<key>ControlURL</key>
|
||||||
<string>{{.Url}}</string>
|
<string>{{.URL}}</string>
|
||||||
</dict>
|
</dict>
|
||||||
`))
|
`))
|
||||||
|
|
19
cli_test.go
19
cli_test.go
|
@ -7,31 +7,34 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestRegisterMachine(c *check.C) {
|
func (s *Suite) TestRegisterMachine(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
IPAddress: "10.0.0.1",
|
IPAddress: "10.0.0.1",
|
||||||
Expiry: &now,
|
Expiry: &now,
|
||||||
RequestedExpiry: &now,
|
RequestedExpiry: &now,
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name)
|
machineAfterRegistering, err := app.RegisterMachine(
|
||||||
|
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||||
|
namespace.Name,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(m2.Registered, check.Equals, true)
|
c.Assert(machineAfterRegistering.Registered, check.Equals, true)
|
||||||
|
|
||||||
_, err = m2.GetHostInfo()
|
_, err = machineAfterRegistering.GetHostInfo()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,8 @@ func init() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
createNodeCmd.Flags().
|
||||||
|
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
||||||
|
|
||||||
debugCmd.AddCommand(createNodeCmd)
|
debugCmd.AddCommand(createNodeCmd)
|
||||||
}
|
}
|
||||||
|
@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{
|
||||||
namespace, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
name, err := cmd.Flags().GetString("name")
|
name, err := cmd.Flags().GetString("name")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting node from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machineKey, err := cmd.Flags().GetString("key")
|
machineKey, err := cmd.Flags().GetString("key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting key from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := cmd.Flags().GetStringSlice("route")
|
routes, err := cmd.Flags().GetStringSlice("route")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.DebugCreateMachine(ctx, request)
|
response, err := client.DebugCreateMachine(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
survey "github.com/AlecAivazis/survey/v2"
|
survey "github.com/AlecAivazis/survey/v2"
|
||||||
|
"github.com/juanfont/headscale"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -19,6 +20,10 @@ func init() {
|
||||||
namespaceCmd.AddCommand(renameNamespaceCmd)
|
namespaceCmd.AddCommand(renameNamespaceCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
errMissingParameter = headscale.Error("missing parameters")
|
||||||
|
)
|
||||||
|
|
||||||
var namespaceCmd = &cobra.Command{
|
var namespaceCmd = &cobra.Command{
|
||||||
Use: "namespaces",
|
Use: "namespaces",
|
||||||
Short: "Manage the namespaces of Headscale",
|
Short: "Manage the namespaces of Headscale",
|
||||||
|
@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{
|
||||||
Short: "Creates a new namespace",
|
Short: "Creates a new namespace",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return fmt.Errorf("Missing parameters")
|
return errMissingParameter
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{
|
||||||
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
|
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
|
||||||
response, err := client.CreateNamespace(ctx, request)
|
response, err := client.CreateNamespace(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Cannot create namespace: %s",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{
|
||||||
Short: "Destroys a namespace",
|
Short: "Destroys a namespace",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return fmt.Errorf("Missing parameters")
|
return errMissingParameter
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{
|
||||||
|
|
||||||
_, err := client.GetNamespace(ctx, request)
|
_, err := client.GetNamespace(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
prompt := &survey.Confirm{
|
prompt := &survey.Confirm{
|
||||||
Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName),
|
Message: fmt.Sprintf(
|
||||||
|
"Do you want to remove the namespace '%s' and any associated preauthkeys?",
|
||||||
|
namespaceName,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
err := survey.AskOne(prompt, &confirm)
|
err := survey.AskOne(prompt, &confirm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.DeleteNamespace(ctx, request)
|
response, err := client.DeleteNamespace(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Cannot destroy namespace: %s",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SuccessOutput(response, "Namespace destroyed", output)
|
SuccessOutput(response, "Namespace destroyed", output)
|
||||||
|
@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.ListNamespaces(ctx, request)
|
response, err := client.ListNamespaces(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response.Namespaces, "", output)
|
SuccessOutput(response.Namespaces, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := pterm.TableData{{"ID", "Name", "Created"}}
|
tableData := pterm.TableData{{"ID", "Name", "Created"}}
|
||||||
for _, namespace := range response.GetNamespaces() {
|
for _, namespace := range response.GetNamespaces() {
|
||||||
d = append(
|
tableData = append(
|
||||||
d,
|
tableData,
|
||||||
[]string{
|
[]string{
|
||||||
namespace.GetId(),
|
namespace.GetId(),
|
||||||
namespace.GetName(),
|
namespace.GetName(),
|
||||||
|
@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{
|
||||||
Use: "rename OLD_NAME NEW_NAME",
|
Use: "rename OLD_NAME NEW_NAME",
|
||||||
Short: "Renames a namespace",
|
Short: "Renames a namespace",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) < 2 {
|
expectedArguments := 2
|
||||||
return fmt.Errorf("Missing parameters")
|
if len(args) < expectedArguments {
|
||||||
|
return errMissingParameter
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.RenameNamespace(ctx, request)
|
response, err := client.RenameNamespace(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Cannot rename namespace: %s",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
survey "github.com/AlecAivazis/survey/v2"
|
survey "github.com/AlecAivazis/survey/v2"
|
||||||
|
"github.com/juanfont/headscale"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{
|
||||||
namespace, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
machineKey, err := cmd.Flags().GetString("key")
|
machineKey, err := cmd.Flags().GetString("key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting machine key from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.RegisterMachine(ctx, request)
|
response, err := client.RegisterMachine(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Cannot register machine: %s\n",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{
|
||||||
namespace, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.ListMachines(ctx, request)
|
response, err := client.ListMachines(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response.Machines, "", output)
|
SuccessOutput(response.Machines, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := nodesToPtables(namespace, response.Machines)
|
tableData, err := nodesToPtables(namespace, response.Machines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
id, err := cmd.Flags().GetInt("identifier")
|
identifier, err := cmd.Flags().GetInt("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
getRequest := &v1.GetMachineRequest{
|
getRequest := &v1.GetMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
getResponse, err := client.GetMachine(ctx, getRequest)
|
getResponse, err := client.GetMachine(ctx, getRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Error getting node node: %s",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
deleteRequest := &v1.DeleteMachineRequest{
|
deleteRequest := &v1.DeleteMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
confirm := false
|
confirm := false
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
prompt := &survey.Confirm{
|
prompt := &survey.Confirm{
|
||||||
Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name),
|
Message: fmt.Sprintf(
|
||||||
|
"Do you want to remove the node %s?",
|
||||||
|
getResponse.GetMachine().Name,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
err = survey.AskOne(prompt, &confirm)
|
err = survey.AskOne(prompt, &confirm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
response, err := client.DeleteMachine(ctx, deleteRequest)
|
response, err := client.DeleteMachine(ctx, deleteRequest)
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response, "", output)
|
SuccessOutput(response, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Error deleting node: %s",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output)
|
SuccessOutput(
|
||||||
|
map[string]string{"Result": "Node deleted"},
|
||||||
|
"Node deleted",
|
||||||
|
output,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
|
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
|
||||||
}
|
}
|
||||||
|
@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
func sharingWorker(
|
func sharingWorker(
|
||||||
cmd *cobra.Command,
|
cmd *cobra.Command,
|
||||||
args []string,
|
|
||||||
) (string, *v1.Machine, *v1.Namespace, error) {
|
) (string, *v1.Machine, *v1.Namespace, error) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
namespaceStr, err := cmd.Flags().GetString("namespace")
|
namespaceStr, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return "", nil, nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,19 +280,25 @@ func sharingWorker(
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
id, err := cmd.Flags().GetInt("identifier")
|
identifier, err := cmd.Flags().GetInt("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
||||||
|
|
||||||
return "", nil, nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machineRequest := &v1.GetMachineRequest{
|
machineRequest := &v1.GetMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return "", nil, nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,7 +308,12 @@ func sharingWorker(
|
||||||
|
|
||||||
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
|
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return "", nil, nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{
|
||||||
Use: "share",
|
Use: "share",
|
||||||
Short: "Shares a node from the current namespace to the specified one",
|
Short: "Shares a node from the current namespace to the specified one",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, machine, namespace, err := sharingWorker(cmd, args)
|
output, machine, namespace, err := sharingWorker(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.ShareMachine(ctx, request)
|
response, err := client.ShareMachine(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{
|
||||||
Use: "unshare",
|
Use: "unshare",
|
||||||
Short: "Unshares a node from the specified namespace",
|
Short: "Unshares a node from the specified namespace",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, machine, namespace, err := sharingWorker(cmd, args)
|
output, machine, namespace, err := sharingWorker(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.UnshareMachine(ctx, request)
|
response, err := client.UnshareMachine(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) {
|
func nodesToPtables(
|
||||||
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
|
currentNamespace string,
|
||||||
|
machines []*v1.Machine,
|
||||||
|
) (pterm.TableData, error) {
|
||||||
|
tableData := pterm.TableData{
|
||||||
|
{
|
||||||
|
"ID",
|
||||||
|
"Name",
|
||||||
|
"NodeKey",
|
||||||
|
"Namespace",
|
||||||
|
"IP address",
|
||||||
|
"Ephemeral",
|
||||||
|
"Last seen",
|
||||||
|
"Online",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
|
@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
||||||
nodeKey := tailcfg.NodeKey(nKey)
|
nodeKey := tailcfg.NodeKey(nKey)
|
||||||
|
|
||||||
var online string
|
var online string
|
||||||
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
|
if lastSeen.After(
|
||||||
|
time.Now().Add(-5 * time.Minute),
|
||||||
|
) { // TODO: Find a better way to reliably show if online
|
||||||
online = pterm.LightGreen("true")
|
online = pterm.LightGreen("true")
|
||||||
} else {
|
} else {
|
||||||
online = pterm.LightRed("false")
|
online = pterm.LightRed("false")
|
||||||
|
@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
||||||
// Shared into this namespace
|
// Shared into this namespace
|
||||||
namespace = pterm.LightYellow(machine.Namespace.Name)
|
namespace = pterm.LightYellow(machine.Namespace.Name)
|
||||||
}
|
}
|
||||||
d = append(
|
tableData = append(
|
||||||
d,
|
tableData,
|
||||||
[]string{
|
[]string{
|
||||||
strconv.FormatUint(machine.Id, 10),
|
strconv.FormatUint(machine.Id, headscale.Base10),
|
||||||
machine.Name,
|
machine.Name,
|
||||||
nodeKey.ShortString(),
|
nodeKey.ShortString(),
|
||||||
namespace,
|
namespace,
|
||||||
|
@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return d, nil
|
|
||||||
|
return tableData, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,10 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultPreAuthKeyExpiry = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(preauthkeysCmd)
|
rootCmd.AddCommand(preauthkeysCmd)
|
||||||
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
|
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
|
||||||
|
@ -22,10 +26,12 @@ func init() {
|
||||||
preauthkeysCmd.AddCommand(listPreAuthKeys)
|
preauthkeysCmd.AddCommand(listPreAuthKeys)
|
||||||
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
|
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
|
||||||
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
|
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
|
||||||
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
|
createPreAuthKeyCmd.PersistentFlags().
|
||||||
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
|
Bool("reusable", false, "Make the preauthkey reusable")
|
||||||
|
createPreAuthKeyCmd.PersistentFlags().
|
||||||
|
Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
|
||||||
createPreAuthKeyCmd.Flags().
|
createPreAuthKeyCmd.Flags().
|
||||||
DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
|
DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
|
||||||
}
|
}
|
||||||
|
|
||||||
var preauthkeysCmd = &cobra.Command{
|
var preauthkeysCmd = &cobra.Command{
|
||||||
|
@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
n, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
request := &v1.ListPreAuthKeysRequest{
|
request := &v1.ListPreAuthKeysRequest{
|
||||||
Namespace: n,
|
Namespace: namespace,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.ListPreAuthKeys(ctx, request)
|
response, err := client.ListPreAuthKeys(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response.PreAuthKeys, "", output)
|
SuccessOutput(response.PreAuthKeys, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
|
tableData := pterm.TableData{
|
||||||
for _, k := range response.PreAuthKeys {
|
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
|
||||||
|
}
|
||||||
|
for _, key := range response.PreAuthKeys {
|
||||||
expiration := "-"
|
expiration := "-"
|
||||||
if k.GetExpiration() != nil {
|
if key.GetExpiration() != nil {
|
||||||
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
var reusable string
|
var reusable string
|
||||||
if k.GetEphemeral() {
|
if key.GetEphemeral() {
|
||||||
reusable = "N/A"
|
reusable = "N/A"
|
||||||
} else {
|
} else {
|
||||||
reusable = fmt.Sprintf("%v", k.GetReusable())
|
reusable = fmt.Sprintf("%v", key.GetReusable())
|
||||||
}
|
}
|
||||||
|
|
||||||
d = append(d, []string{
|
tableData = append(tableData, []string{
|
||||||
k.GetId(),
|
key.GetId(),
|
||||||
k.GetKey(),
|
key.GetKey(),
|
||||||
reusable,
|
reusable,
|
||||||
strconv.FormatBool(k.GetEphemeral()),
|
strconv.FormatBool(key.GetEphemeral()),
|
||||||
strconv.FormatBool(k.GetUsed()),
|
strconv.FormatBool(key.GetUsed()),
|
||||||
expiration,
|
expiration,
|
||||||
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{
|
||||||
namespace, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.CreatePreAuthKey(ctx, request)
|
response, err := client.CreatePreAuthKey(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
||||||
Short: "Expire a preauthkey",
|
Short: "Expire a preauthkey",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return fmt.Errorf("missing parameters")
|
return errMissingParameter
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
||||||
namespace, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
||||||
|
|
||||||
response, err := client.ExpirePreAuthKey(ctx, request)
|
response, err := client.ExpirePreAuthKey(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,8 @@ import (
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.PersistentFlags().
|
rootCmd.PersistentFlags().
|
||||||
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
|
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
|
||||||
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
|
rootCmd.PersistentFlags().
|
||||||
|
Bool("force", false, "Disable prompts and forces the execution")
|
||||||
}
|
}
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
|
|
|
@ -21,7 +21,8 @@ func init() {
|
||||||
}
|
}
|
||||||
routesCmd.AddCommand(listRoutesCmd)
|
routesCmd.AddCommand(listRoutesCmd)
|
||||||
|
|
||||||
enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
|
enableRouteCmd.Flags().
|
||||||
|
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
|
||||||
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
err = enableRouteCmd.MarkFlagRequired("identifier")
|
err = enableRouteCmd.MarkFlagRequired("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
machineId, err := cmd.Flags().GetUint64("identifier")
|
machineID, err := cmd.Flags().GetUint64("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
request := &v1.GetMachineRouteRequest{
|
request := &v1.GetMachineRouteRequest{
|
||||||
MachineId: machineId,
|
MachineId: machineID,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.GetMachineRoute(ctx, request)
|
response, err := client.GetMachineRoute(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response.Routes, "", output)
|
SuccessOutput(response.Routes, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := routesToPtables(response.Routes)
|
tableData := routesToPtables(response.Routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -93,15 +111,26 @@ omit the route you do not want to enable.
|
||||||
`,
|
`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
machineId, err := cmd.Flags().GetUint64("identifier")
|
|
||||||
|
machineID, err := cmd.Flags().GetUint64("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := cmd.Flags().GetStringSlice("route")
|
routes, err := cmd.Flags().GetStringSlice("route")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,45 +139,61 @@ omit the route you do not want to enable.
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
request := &v1.EnableMachineRoutesRequest{
|
request := &v1.EnableMachineRoutesRequest{
|
||||||
MachineId: machineId,
|
MachineId: machineID,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.EnableMachineRoutes(ctx, request)
|
response, err := client.EnableMachineRoutes(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"Cannot register machine: %s\n",
|
||||||
|
status.Convert(err).Message(),
|
||||||
|
),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if output != "" {
|
if output != "" {
|
||||||
SuccessOutput(response.Routes, "", output)
|
SuccessOutput(response.Routes, "", output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := routesToPtables(response.Routes)
|
tableData := routesToPtables(response.Routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
ErrorOutput(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// routesToPtables converts the list of routes to a nice table
|
// routesToPtables converts the list of routes to a nice table.
|
||||||
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
||||||
d := pterm.TableData{{"Route", "Enabled"}}
|
tableData := pterm.TableData{{"Route", "Enabled"}}
|
||||||
|
|
||||||
for _, route := range routes.GetAdvertisedRoutes() {
|
for _, route := range routes.GetAdvertisedRoutes() {
|
||||||
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
||||||
|
|
||||||
d = append(d, []string{route, strconv.FormatBool(enabled)})
|
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
|
||||||
}
|
}
|
||||||
return d
|
|
||||||
|
return tableData
|
||||||
}
|
}
|
||||||
|
|
||||||
func isStringInSlice(strs []string, s string) bool {
|
func isStringInSlice(strs []string, s string) bool {
|
||||||
|
|
|
@ -52,9 +52,8 @@ func LoadConfig(path string) error {
|
||||||
viper.SetDefault("cli.insecure", false)
|
viper.SetDefault("cli.insecure", false)
|
||||||
viper.SetDefault("cli.timeout", "5s")
|
viper.SetDefault("cli.timeout", "5s")
|
||||||
|
|
||||||
err := viper.ReadInConfig()
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||||
return fmt.Errorf("Fatal error reading config file: %s \n", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect any validation errors and return them all at once
|
// Collect any validation errors and return them all at once
|
||||||
|
@ -82,6 +81,7 @@ func LoadConfig(path string) error {
|
||||||
errorText += "Fatal config error: server_url must start with https:// or http://\n"
|
errorText += "Fatal config error: server_url must start with https:// or http://\n"
|
||||||
}
|
}
|
||||||
if errorText != "" {
|
if errorText != "" {
|
||||||
|
//nolint
|
||||||
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
|
@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
||||||
if viper.IsSet("dns_config.restricted_nameservers") {
|
if viper.IsSet("dns_config.restricted_nameservers") {
|
||||||
if len(dnsConfig.Nameservers) > 0 {
|
if len(dnsConfig.Nameservers) > 0 {
|
||||||
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
|
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||||
restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
|
restrictedDNS := viper.GetStringMapStringSlice(
|
||||||
|
"dns_config.restricted_nameservers",
|
||||||
|
)
|
||||||
for domain, restrictedNameservers := range restrictedDNS {
|
for domain, restrictedNameservers := range restrictedDNS {
|
||||||
restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers))
|
restrictedResolvers := make(
|
||||||
|
[]dnstype.Resolver,
|
||||||
|
len(restrictedNameservers),
|
||||||
|
)
|
||||||
for index, nameserverStr := range restrictedNameservers {
|
for index, nameserverStr := range restrictedNameservers {
|
||||||
nameserver, err := netaddr.ParseIP(nameserverStr)
|
nameserver, err := netaddr.ParseIP(nameserverStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -208,6 +213,7 @@ func absPath(path string) string {
|
||||||
path = filepath.Join(dir, path)
|
path = filepath.Join(dir, path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config {
|
||||||
"10h",
|
"10h",
|
||||||
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
|
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
|
||||||
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
|
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
|
||||||
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
|
maxMachineRegistrationDuration = viper.GetDuration(
|
||||||
|
"max_machine_registration_duration",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
|
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
|
||||||
|
@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config {
|
||||||
"8h",
|
"8h",
|
||||||
) // use 8h here because it's the length of a standard business day
|
) // use 8h here because it's the length of a standard business day
|
||||||
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
|
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
|
||||||
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
|
defaultMachineRegistrationDuration = viper.GetDuration(
|
||||||
|
"default_machine_registration_duration",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsConfig, baseDomain := GetDNSConfig()
|
dnsConfig, baseDomain := GetDNSConfig()
|
||||||
|
@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config {
|
||||||
|
|
||||||
DERP: derpConfig,
|
DERP: derpConfig,
|
||||||
|
|
||||||
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
|
EphemeralNodeInactivityTimeout: viper.GetDuration(
|
||||||
|
"ephemeral_node_inactivity_timeout",
|
||||||
|
),
|
||||||
|
|
||||||
DBtype: viper.GetString("db_type"),
|
DBtype: viper.GetString("db_type"),
|
||||||
DBpath: absPath(viper.GetString("db_path")),
|
DBpath: absPath(viper.GetString("db_path")),
|
||||||
|
@ -256,7 +268,9 @@ func getHeadscaleConfig() headscale.Config {
|
||||||
|
|
||||||
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
|
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||||
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
|
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
|
||||||
TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")),
|
TLSLetsEncryptCacheDir: absPath(
|
||||||
|
viper.GetString("tls_letsencrypt_cache_dir"),
|
||||||
|
),
|
||||||
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
||||||
|
|
||||||
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
|
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
|
||||||
|
@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
// to avoid races
|
// to avoid races
|
||||||
minInactivityTimeout, _ := time.ParseDuration("65s")
|
minInactivityTimeout, _ := time.ParseDuration("65s")
|
||||||
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
|
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
|
||||||
|
// TODO: Find a better way to return this text
|
||||||
|
//nolint
|
||||||
err := fmt.Errorf(
|
err := fmt.Errorf(
|
||||||
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n",
|
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
|
||||||
viper.GetString("ephemeral_node_inactivity_timeout"),
|
viper.GetString("ephemeral_node_inactivity_timeout"),
|
||||||
minInactivityTimeout,
|
minInactivityTimeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
|
|
||||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||||
|
|
||||||
h, err := headscale.NewHeadscale(cfg)
|
app, err := headscale.NewHeadscale(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
|
|
||||||
if viper.GetString("acl_policy_path") != "" {
|
if viper.GetString("acl_policy_path") != "" {
|
||||||
aclPath := absPath(viper.GetString("acl_policy_path"))
|
aclPath := absPath(viper.GetString("acl_policy_path"))
|
||||||
err = h.LoadACLPolicy(aclPath)
|
err = app.LoadACLPolicy(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
|
@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return h, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||||
|
@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
|
||||||
|
|
||||||
// If the address is not set, we assume that we are on the server hosting headscale.
|
// If the address is not set, we assume that we are on the server hosting headscale.
|
||||||
if address == "" {
|
if address == "" {
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("socket", cfg.UnixSocket).
|
Str("socket", cfg.UnixSocket).
|
||||||
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
|
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
|
||||||
|
@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
|
||||||
log.Fatal().Err(err)
|
log.Fatal().Err(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
//nolint
|
||||||
fmt.Println(override)
|
fmt.Println(override)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint
|
||||||
fmt.Println(string(j))
|
fmt.Println(string(j))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,7 +451,10 @@ type tokenAuth struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return value is mapped to request headers.
|
// Return value is mapped to request headers.
|
||||||
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
|
func (t tokenAuth) GetRequestMetadata(
|
||||||
|
ctx context.Context,
|
||||||
|
in ...string,
|
||||||
|
) (map[string]string, error) {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
"authorization": "Bearer " + t.token,
|
"authorization": "Bearer " + t.token,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
|
@ -23,6 +23,8 @@ func main() {
|
||||||
colors = true
|
colors = true
|
||||||
case termcolor.LevelBasic:
|
case termcolor.LevelBasic:
|
||||||
colors = true
|
colors = true
|
||||||
|
case termcolor.LevelNone:
|
||||||
|
colors = false
|
||||||
default:
|
default:
|
||||||
// no color, return text as is.
|
// no color, return text as is.
|
||||||
colors = false
|
colors = false
|
||||||
|
@ -41,8 +43,7 @@ func main() {
|
||||||
NoColor: !colors,
|
NoColor: !colors,
|
||||||
})
|
})
|
||||||
|
|
||||||
err := cli.LoadConfig("")
|
if err := cli.LoadConfig(""); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err)
|
log.Fatal().Err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,13 +64,15 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !viper.GetBool("disable_check_updates") && !machineOutput {
|
if !viper.GetBool("disable_check_updates") && !machineOutput {
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||||
|
cli.Version != "dev" {
|
||||||
githubTag := &latest.GithubTag{
|
githubTag := &latest.GithubTag{
|
||||||
Owner: "juanfont",
|
Owner: "juanfont",
|
||||||
Repository: "headscale",
|
Repository: "headscale",
|
||||||
}
|
}
|
||||||
res, err := latest.Check(githubTag, cli.Version)
|
res, err := latest.Check(githubTag, cli.Version)
|
||||||
if err == nil && res.Outdated {
|
if err == nil && res.Outdated {
|
||||||
|
//nolint
|
||||||
fmt.Printf(
|
fmt.Printf(
|
||||||
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
|
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
|
||||||
res.Current,
|
res.Current,
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Symlink the example config file
|
// Symlink the example config file
|
||||||
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
|
err = os.Symlink(
|
||||||
|
filepath.Clean(path+"/../../config-example.yaml"),
|
||||||
|
filepath.Join(tmpDir, "config.yaml"),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Symlink the example config file
|
// Symlink the example config file
|
||||||
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
|
err = os.Symlink(
|
||||||
|
filepath.Clean(path+"/../../config-example.yaml"),
|
||||||
|
filepath.Join(tmpDir, "config.yaml"),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
|
||||||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
||||||
// Populate a custom config file
|
// Populate a custom config file
|
||||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||||
err := ioutil.WriteFile(configFile, configYaml, 0o644)
|
err := ioutil.WriteFile(configFile, configYaml, 0o600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatalf("Couldn't write file %s", configFile)
|
c.Fatalf("Couldn't write file %s", configFile)
|
||||||
}
|
}
|
||||||
|
@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
// defer os.RemoveAll(tmpDir)
|
// defer os.RemoveAll(tmpDir)
|
||||||
fmt.Println(tmpDir)
|
|
||||||
|
|
||||||
configYaml := []byte(
|
configYaml := []byte(
|
||||||
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
|
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
|
||||||
|
@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
||||||
check.Matches,
|
check.Matches,
|
||||||
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
|
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
|
||||||
)
|
)
|
||||||
c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*")
|
c.Assert(
|
||||||
fmt.Println(tmp)
|
tmp,
|
||||||
|
check.Matches,
|
||||||
|
".*Fatal config error: server_url must start with https:// or http://.*",
|
||||||
|
)
|
||||||
|
|
||||||
// Check configuration validation errors (2)
|
// Check configuration validation errors (2)
|
||||||
configYaml = []byte(
|
configYaml = []byte(
|
||||||
|
|
35
db.go
35
db.go
|
@ -9,7 +9,10 @@ import (
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dbVersion = "1"
|
const (
|
||||||
|
dbVersion = "1"
|
||||||
|
errValueNotFound = Error("not found")
|
||||||
|
)
|
||||||
|
|
||||||
// KV is a key-value store in a psql table. For future use...
|
// KV is a key-value store in a psql table. For future use...
|
||||||
type KV struct {
|
type KV struct {
|
||||||
|
@ -24,7 +27,7 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
h.db = db
|
h.db = db
|
||||||
|
|
||||||
if h.dbType == "postgres" {
|
if h.dbType == Postgres {
|
||||||
db.Exec("create extension if not exists \"uuid-ossp\";")
|
db.Exec("create extension if not exists \"uuid-ossp\";")
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Machine{})
|
err = db.AutoMigrate(&Machine{})
|
||||||
|
@ -50,6 +53,7 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.setValue("db_version", dbVersion)
|
err = h.setValue("db_version", dbVersion)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch h.dbType {
|
switch h.dbType {
|
||||||
case "sqlite3":
|
case Sqlite:
|
||||||
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
|
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
|
||||||
DisableForeignKeyConstraintWhenMigrating: true,
|
DisableForeignKeyConstraintWhenMigrating: true,
|
||||||
Logger: log,
|
Logger: log,
|
||||||
})
|
})
|
||||||
case "postgres":
|
case Postgres:
|
||||||
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
|
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
|
||||||
DisableForeignKeyConstraintWhenMigrating: true,
|
DisableForeignKeyConstraintWhenMigrating: true,
|
||||||
Logger: log,
|
Logger: log,
|
||||||
|
@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValue returns the value for the given key in KV
|
// getValue returns the value for the given key in KV.
|
||||||
func (h *Headscale) getValue(key string) (string, error) {
|
func (h *Headscale) getValue(key string) (string, error) {
|
||||||
var row KV
|
var row KV
|
||||||
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.First(&row, "key = ?", key); errors.Is(
|
||||||
return "", errors.New("not found")
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return "", errValueNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return row.Value, nil
|
return row.Value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setValue sets value for the given key in KV
|
// setValue sets value for the given key in KV.
|
||||||
func (h *Headscale) setValue(key string, value string) error {
|
func (h *Headscale) setValue(key string, value string) error {
|
||||||
kv := KV{
|
keyValue := KV{
|
||||||
Key: key,
|
Key: key,
|
||||||
Value: value,
|
Value: value,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := h.getValue(key)
|
if _, err := h.getValue(key); err == nil {
|
||||||
if err == nil {
|
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||||
h.db.Model(&kv).Where("key = ?", key).Update("value", value)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.Create(kv)
|
h.db.Create(keyValue)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
24
derp.go
24
derp.go
|
@ -1,6 +1,7 @@
|
||||||
package headscale
|
package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -10,9 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = yaml.Unmarshal(b, &derpMap)
|
err = yaml.Unmarshal(b, &derpMap)
|
||||||
|
|
||||||
return &derpMap, err
|
return &derpMap, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||||
client := http.Client{
|
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
|
||||||
Timeout: 10 * time.Second,
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := client.Get(addr.String())
|
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: HTTPReadTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||||
|
|
||||||
var derpMap tailcfg.DERPMap
|
var derpMap tailcfg.DERPMap
|
||||||
err = json.Unmarshal(body, &derpMap)
|
err = json.Unmarshal(body, &derpMap)
|
||||||
|
|
||||||
return &derpMap, err
|
return &derpMap, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||||
// DERPMap, it will _only_ look at the Regions, an integer.
|
// DERPMap, it will _only_ look at the Regions, an integer.
|
||||||
// If a region exists in two of the given DERPMaps, the region
|
// If a region exists in two of the given DERPMaps, the region
|
||||||
// form the _last_ DERPMap will be preserved.
|
// form the _last_ DERPMap will be preserved.
|
||||||
// An empty DERPMap list will result in a DERPMap with no regions
|
// An empty DERPMap list will result in a DERPMap with no regions.
|
||||||
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
||||||
result := tailcfg.DERPMap{
|
result := tailcfg.DERPMap{
|
||||||
OmitDefaultRegions: false,
|
OmitDefaultRegions: false,
|
||||||
|
@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not load DERP map from path")
|
Msg("Could not load DERP map from path")
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
|
||||||
Str("url", addr.String()).
|
Str("url", addr.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not load DERP map from path")
|
Msg("Could not load DERP map from path")
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
34
dns.go
34
dns.go
|
@ -10,6 +10,10 @@ import (
|
||||||
"tailscale.com/util/dnsname"
|
"tailscale.com/util/dnsname"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ByteSize = 8
|
||||||
|
)
|
||||||
|
|
||||||
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
||||||
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
||||||
// server (listening in 100.100.100.100 udp/53) should be used for.
|
// server (listening in 100.100.100.100 udp/53) should be used for.
|
||||||
|
@ -30,7 +34,9 @@ import (
|
||||||
|
|
||||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||||
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
|
func generateMagicDNSRootDomains(
|
||||||
|
ipPrefix netaddr.IPPrefix,
|
||||||
|
) []dnsname.FQDN {
|
||||||
// TODO(juanfont): we are not handing out IPv6 addresses yet
|
// TODO(juanfont): we are not handing out IPv6 addresses yet
|
||||||
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
|
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
|
||||||
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
|
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
|
||||||
|
@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
|
||||||
maskBits, _ := netRange.Mask.Size()
|
maskBits, _ := netRange.Mask.Size()
|
||||||
|
|
||||||
// lastOctet is the last IP byte covered by the mask
|
// lastOctet is the last IP byte covered by the mask
|
||||||
lastOctet := maskBits / 8
|
lastOctet := maskBits / ByteSize
|
||||||
|
|
||||||
// wildcardBits is the number of bits not under the mask in the lastOctet
|
// wildcardBits is the number of bits not under the mask in the lastOctet
|
||||||
wildcardBits := 8 - maskBits%8
|
wildcardBits := ByteSize - maskBits%ByteSize
|
||||||
|
|
||||||
// min is the value in the lastOctet byte of the IP
|
// min is the value in the lastOctet byte of the IP
|
||||||
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
|
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
|
||||||
min := uint(netRange.IP[lastOctet])
|
min := uint(netRange.IP[lastOctet])
|
||||||
max := uint((min + 1<<uint(wildcardBits)) - 1)
|
max := (min + 1<<uint(wildcardBits)) - 1
|
||||||
|
|
||||||
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
|
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
|
||||||
rdnsSlice := []string{}
|
rdnsSlice := []string{}
|
||||||
|
@ -66,18 +72,27 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
|
||||||
}
|
}
|
||||||
fqdns = append(fqdns, fqdn)
|
fqdns = append(fqdns, fqdn)
|
||||||
}
|
}
|
||||||
return fqdns, nil
|
|
||||||
|
return fqdns
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, m Machine, peers Machines) (*tailcfg.DNSConfig, error) {
|
func getMapResponseDNSConfig(
|
||||||
|
dnsConfigOrig *tailcfg.DNSConfig,
|
||||||
|
baseDomain string,
|
||||||
|
machine Machine,
|
||||||
|
peers Machines,
|
||||||
|
) *tailcfg.DNSConfig {
|
||||||
var dnsConfig *tailcfg.DNSConfig
|
var dnsConfig *tailcfg.DNSConfig
|
||||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
||||||
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
|
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
|
||||||
dnsConfig = dnsConfigOrig.Clone()
|
dnsConfig = dnsConfigOrig.Clone()
|
||||||
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain))
|
dnsConfig.Domains = append(
|
||||||
|
dnsConfig.Domains,
|
||||||
|
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
|
||||||
|
)
|
||||||
|
|
||||||
namespaceSet := set.New(set.ThreadSafe)
|
namespaceSet := set.New(set.ThreadSafe)
|
||||||
namespaceSet.Add(m.Namespace)
|
namespaceSet.Add(machine.Namespace)
|
||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
namespaceSet.Add(p.Namespace)
|
namespaceSet.Add(p.Namespace)
|
||||||
}
|
}
|
||||||
|
@ -88,5 +103,6 @@ func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string
|
||||||
} else {
|
} else {
|
||||||
dnsConfig = dnsConfigOrig
|
dnsConfig = dnsConfigOrig
|
||||||
}
|
}
|
||||||
return dnsConfig, nil
|
|
||||||
|
return dnsConfig
|
||||||
}
|
}
|
||||||
|
|
211
dns_test.go
211
dns_test.go
|
@ -11,13 +11,13 @@ import (
|
||||||
|
|
||||||
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||||
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
|
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
|
||||||
domains, err := generateMagicDNSRootDomains(prefix, "foobar.headscale.net")
|
domains := generateMagicDNSRootDomains(prefix)
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "64.100.in-addr.arpa." {
|
if domain == "64.100.in-addr.arpa." {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,6 +27,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "100.100.in-addr.arpa." {
|
if domain == "100.100.in-addr.arpa." {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,6 +37,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "127.100.in-addr.arpa." {
|
if domain == "127.100.in-addr.arpa." {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,13 +46,13 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||||
|
|
||||||
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
||||||
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
|
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
|
||||||
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net")
|
domains := generateMagicDNSRootDomains(prefix)
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "0.16.172.in-addr.arpa." {
|
if domain == "0.16.172.in-addr.arpa." {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,6 +62,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "255.16.172.in-addr.arpa." {
|
if domain == "255.16.172.in-addr.arpa." {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -67,100 +70,120 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
n1, err := h.CreateNamespace("shared1")
|
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n2, err := h.CreateNamespace("shared2")
|
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n3, err := h.CreateNamespace("shared3")
|
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared2.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared3.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
PreAuthKey2InShared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
Name: "test_get_shared_nodes_1",
|
Name: "test_get_shared_nodes_1",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.1",
|
IPAddress: "100.64.0.1",
|
||||||
AuthKeyID: uint(pak1n1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m1)
|
app.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_2",
|
Name: "test_get_shared_nodes_2",
|
||||||
NamespaceID: n2.ID,
|
NamespaceID: namespaceShared2.ID,
|
||||||
Namespace: *n2,
|
Namespace: *namespaceShared2,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.2",
|
IPAddress: "100.64.0.2",
|
||||||
AuthKeyID: uint(pak2n2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m2)
|
app.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_3",
|
Name: "test_get_shared_nodes_3",
|
||||||
NamespaceID: n3.ID,
|
NamespaceID: namespaceShared3.ID,
|
||||||
Namespace: *n3,
|
Namespace: *namespaceShared3,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.3",
|
IPAddress: "100.64.0.3",
|
||||||
AuthKeyID: uint(pak3n3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m3)
|
app.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m4 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_4",
|
Name: "test_get_shared_nodes_4",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.4",
|
IPAddress: "100.64.0.4",
|
||||||
AuthKeyID: uint(pak4n1.ID),
|
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m4)
|
app.db.Save(machine2InShared1)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
|
@ -170,122 +193,146 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Proxied: true,
|
Proxied: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
m1peers, err := h.getPeers(m1)
|
peersOfMachineInShared1, err := app.getPeers(machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
c.Assert(err, check.IsNil)
|
&dnsConfigOrig,
|
||||||
|
baseDomain,
|
||||||
|
*machineInShared1,
|
||||||
|
peersOfMachineInShared1,
|
||||||
|
)
|
||||||
c.Assert(dnsConfig, check.NotNil)
|
c.Assert(dnsConfig, check.NotNil)
|
||||||
c.Assert(len(dnsConfig.Routes), check.Equals, 2)
|
c.Assert(len(dnsConfig.Routes), check.Equals, 2)
|
||||||
|
|
||||||
routeN1 := fmt.Sprintf("%s.%s", n1.Name, baseDomain)
|
domainRouteShared1 := fmt.Sprintf("%s.%s", namespaceShared1.Name, baseDomain)
|
||||||
_, ok := dnsConfig.Routes[routeN1]
|
_, ok := dnsConfig.Routes[domainRouteShared1]
|
||||||
c.Assert(ok, check.Equals, true)
|
c.Assert(ok, check.Equals, true)
|
||||||
|
|
||||||
routeN2 := fmt.Sprintf("%s.%s", n2.Name, baseDomain)
|
domainRouteShared2 := fmt.Sprintf("%s.%s", namespaceShared2.Name, baseDomain)
|
||||||
_, ok = dnsConfig.Routes[routeN2]
|
_, ok = dnsConfig.Routes[domainRouteShared2]
|
||||||
c.Assert(ok, check.Equals, true)
|
c.Assert(ok, check.Equals, true)
|
||||||
|
|
||||||
routeN3 := fmt.Sprintf("%s.%s", n3.Name, baseDomain)
|
domainRouteShared3 := fmt.Sprintf("%s.%s", namespaceShared3.Name, baseDomain)
|
||||||
_, ok = dnsConfig.Routes[routeN3]
|
_, ok = dnsConfig.Routes[domainRouteShared3]
|
||||||
c.Assert(ok, check.Equals, false)
|
c.Assert(ok, check.Equals, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
n1, err := h.CreateNamespace("shared1")
|
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n2, err := h.CreateNamespace("shared2")
|
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n3, err := h.CreateNamespace("shared3")
|
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared2.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared3.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
preAuthKey2InShared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
Name: "test_get_shared_nodes_1",
|
Name: "test_get_shared_nodes_1",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.1",
|
IPAddress: "100.64.0.1",
|
||||||
AuthKeyID: uint(pak1n1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m1)
|
app.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_2",
|
Name: "test_get_shared_nodes_2",
|
||||||
NamespaceID: n2.ID,
|
NamespaceID: namespaceShared2.ID,
|
||||||
Namespace: *n2,
|
Namespace: *namespaceShared2,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.2",
|
IPAddress: "100.64.0.2",
|
||||||
AuthKeyID: uint(pak2n2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m2)
|
app.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_3",
|
Name: "test_get_shared_nodes_3",
|
||||||
NamespaceID: n3.ID,
|
NamespaceID: namespaceShared3.ID,
|
||||||
Namespace: *n3,
|
Namespace: *namespaceShared3,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.3",
|
IPAddress: "100.64.0.3",
|
||||||
AuthKeyID: uint(pak3n3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m3)
|
app.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m4 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_4",
|
Name: "test_get_shared_nodes_4",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.4",
|
IPAddress: "100.64.0.4",
|
||||||
AuthKeyID: uint(pak4n1.ID),
|
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m4)
|
app.db.Save(machine2InShared1)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
|
@ -295,11 +342,15 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Proxied: false,
|
Proxied: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
m1peers, err := h.getPeers(m1)
|
peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
c.Assert(err, check.IsNil)
|
&dnsConfigOrig,
|
||||||
|
baseDomain,
|
||||||
|
*machineInShared1,
|
||||||
|
peersOfMachine1Shared1,
|
||||||
|
)
|
||||||
c.Assert(dnsConfig, check.NotNil)
|
c.Assert(dnsConfig, check.NotNil)
|
||||||
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
|
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
|
||||||
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
|
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Running headscale
|
# Running headscale
|
||||||
|
|
||||||
## Server configuration
|
## Server configuration
|
||||||
|
|
||||||
1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container
|
1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -22,6 +23,7 @@
|
||||||
3. Get yourself a DB
|
3. Get yourself a DB
|
||||||
|
|
||||||
a) Get a Postgres DB running in docker
|
a) Get a Postgres DB running in docker
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker run --name headscale \
|
docker run --name headscale \
|
||||||
-e POSTGRES_DB=headscale
|
-e POSTGRES_DB=headscale
|
||||||
|
@ -30,7 +32,9 @@
|
||||||
-p 5432:5432 \
|
-p 5432:5432 \
|
||||||
-d postgres
|
-d postgres
|
||||||
```
|
```
|
||||||
|
|
||||||
or b) Prepare a SQLite DB file
|
or b) Prepare a SQLite DB file
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
touch config/db.sqlite
|
touch config/db.sqlite
|
||||||
```
|
```
|
||||||
|
@ -81,6 +85,7 @@
|
||||||
-p 127.0.0.1:8080:8080 \
|
-p 127.0.0.1:8080:8080 \
|
||||||
headscale/headscale:x.x.x headscale serve
|
headscale/headscale:x.x.x headscale serve
|
||||||
```
|
```
|
||||||
|
|
||||||
## Nodes configuration
|
## Nodes configuration
|
||||||
|
|
||||||
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
|
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
|
||||||
|
@ -90,7 +95,9 @@ If you used tailscale.com before in your nodes, make sure you clear the tailscal
|
||||||
rm -fr /var/lib/tailscale
|
rm -fr /var/lib/tailscale
|
||||||
systemctl start tailscaled
|
systemctl start tailscaled
|
||||||
```
|
```
|
||||||
|
|
||||||
### Adding node based on MACHINEKEY
|
### Adding node based on MACHINEKEY
|
||||||
|
|
||||||
1. Add your first machine
|
1. Add your first machine
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|
17
grpcv1.go
17
grpcv1.go
|
@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RegisterMachineRequest,
|
request *v1.RegisterMachineRequest,
|
||||||
) (*v1.RegisterMachineResponse, error) {
|
) (*v1.RegisterMachineResponse, error) {
|
||||||
log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine")
|
log.Trace().
|
||||||
|
Str("namespace", request.GetNamespace()).
|
||||||
|
Str("machine_key", request.GetKey()).
|
||||||
|
Msg("Registering machine")
|
||||||
machine, err := api.h.RegisterMachine(
|
machine, err := api.h.RegisterMachine(
|
||||||
request.GetKey(),
|
request.GetKey(),
|
||||||
request.GetNamespace(),
|
request.GetNamespace(),
|
||||||
|
@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace())
|
sharedMachines, err := api.h.ListSharedMachinesInNamespace(
|
||||||
|
request.GetNamespace(),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := stringToIpPrefix(request.GetRoutes())
|
routes, err := stringToIPPrefix(request.GetRoutes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("")
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Interface("route-prefix", routes).
|
||||||
|
Interface("route-str", request.GetRoutes()).
|
||||||
|
Msg("")
|
||||||
|
|
||||||
hostinfo := tailcfg.Hostinfo{
|
hostinfo := tailcfg.Hostinfo{
|
||||||
RoutableIPs: routes,
|
RoutableIPs: routes,
|
||||||
|
|
|
@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("status code not OK")
|
return fmt.Errorf("status code not OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
||||||
|
@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
|
func (s *IntegrationCLITestSuite) HandleStats(
|
||||||
|
suiteName string,
|
||||||
|
stats *suite.SuiteInformation,
|
||||||
|
) {
|
||||||
s.stats = stats
|
s.stats = stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() {
|
||||||
namespaces := make([]*v1.Namespace, len(names))
|
namespaces := make([]*v1.Namespace, len(names))
|
||||||
|
|
||||||
for index, namespaceName := range names {
|
for index, namespaceName := range names {
|
||||||
|
|
||||||
namespace, err := s.createNamespace(namespaceName)
|
namespace, err := s.createNamespace(namespaceName)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
|
@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
||||||
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
|
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
|
||||||
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
|
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
|
||||||
|
|
||||||
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
assert.True(
|
||||||
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
s.T(),
|
||||||
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
)
|
||||||
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
|
)
|
||||||
|
|
||||||
// Expire three keys
|
// Expire three keys
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
|
@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
||||||
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
|
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()))
|
assert.True(
|
||||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()))
|
s.T(),
|
||||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()))
|
listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
|
||||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()))
|
)
|
||||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
|
||||||
|
)
|
||||||
|
assert.True(
|
||||||
|
s.T(),
|
||||||
|
listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
|
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
|
||||||
|
@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
var listOnlySharedMachineNamespace []v1.Machine
|
var listOnlySharedMachineNamespace []v1.Machine
|
||||||
err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace)
|
err = json.Unmarshal(
|
||||||
|
[]byte(listOnlySharedMachineNamespaceResult),
|
||||||
|
&listOnlySharedMachineNamespace,
|
||||||
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
|
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
|
||||||
|
@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
var listOnlyMachineNamespaceAfterDelete []v1.Machine
|
var listOnlyMachineNamespaceAfterDelete []v1.Machine
|
||||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete)
|
err = json.Unmarshal(
|
||||||
|
[]byte(listOnlyMachineNamespaceAfterDeleteResult),
|
||||||
|
&listOnlyMachineNamespaceAfterDelete,
|
||||||
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
|
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
|
||||||
|
@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
var listOnlyMachineNamespaceAfterShare []v1.Machine
|
var listOnlyMachineNamespaceAfterShare []v1.Machine
|
||||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare)
|
err = json.Unmarshal(
|
||||||
|
[]byte(listOnlyMachineNamespaceAfterShareResult),
|
||||||
|
&listOnlyMachineNamespaceAfterShare,
|
||||||
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
|
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
|
||||||
|
@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
|
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
|
||||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare)
|
err = json.Unmarshal(
|
||||||
|
[]byte(listOnlyMachineNamespaceAfterUnshareResult),
|
||||||
|
&listOnlyMachineNamespaceAfterUnshare,
|
||||||
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
|
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
|
||||||
|
@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
|
||||||
)
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node")
|
assert.Contains(
|
||||||
|
s.T(),
|
||||||
|
string(failEnableNonAdvertisedRoute),
|
||||||
|
"route (route-machine) is not available on node",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,12 +12,18 @@ import (
|
||||||
"github.com/ory/dockertest/v3/docker"
|
"github.com/ory/dockertest/v3/docker"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) {
|
const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
|
||||||
|
|
||||||
|
func ExecuteCommand(
|
||||||
|
resource *dockertest.Resource,
|
||||||
|
cmd []string,
|
||||||
|
env []string,
|
||||||
|
) (string, error) {
|
||||||
var stdout bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
var stderr bytes.Buffer
|
var stderr bytes.Buffer
|
||||||
|
|
||||||
// TODO(kradalby): Make configurable
|
// TODO(kradalby): Make configurable
|
||||||
timeout := 10 * time.Second
|
timeout := DOCKER_EXECUTE_TIMEOUT
|
||||||
|
|
||||||
type result struct {
|
type result struct {
|
||||||
exitCode int
|
exitCode int
|
||||||
|
@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (
|
||||||
fmt.Println("Command: ", cmd)
|
fmt.Println("Command: ", cmd)
|
||||||
fmt.Println("stdout: ", stdout.String())
|
fmt.Println("stdout: ", stdout.String())
|
||||||
fmt.Println("stderr: ", stderr.String())
|
fmt.Println("stderr: ", stderr.String())
|
||||||
|
|
||||||
return "", fmt.Errorf("command failed with: %s", stderr.String())
|
return "", fmt.Errorf("command failed with: %s", stderr.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return stdout.String(), nil
|
return stdout.String(), nil
|
||||||
case <-time.After(timeout):
|
case <-time.After(timeout):
|
||||||
|
|
||||||
return "", fmt.Errorf("command timed out after %s", timeout)
|
return "", fmt.Errorf("command timed out after %s", timeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,10 +23,9 @@ import (
|
||||||
"github.com/ory/dockertest/v3/docker"
|
"github.com/ory/dockertest/v3/docker"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"inet.af/netaddr"
|
||||||
"tailscale.com/client/tailscale/apitype"
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/ipn/ipnstate"
|
"tailscale.com/ipn/ipnstate"
|
||||||
|
|
||||||
"inet.af/netaddr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
|
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
|
||||||
|
@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error {
|
func (s *IntegrationTestSuite) saveLog(
|
||||||
|
resource *dockertest.Resource,
|
||||||
|
basePath string,
|
||||||
|
) error {
|
||||||
err := os.MkdirAll(basePath, os.ModePerm)
|
err := os.MkdirAll(basePath, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
|
||||||
|
|
||||||
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
|
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
|
||||||
|
|
||||||
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644)
|
err = ioutil.WriteFile(
|
||||||
|
path.Join(basePath, resource.Container.Name+".stdout.log"),
|
||||||
|
[]byte(stdout.String()),
|
||||||
|
0o644,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644)
|
err = ioutil.WriteFile(
|
||||||
|
path.Join(basePath, resource.Container.Name+".stderr.log"),
|
||||||
|
[]byte(stdout.String()),
|
||||||
|
0o644,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer(
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier)
|
hostname := fmt.Sprintf(
|
||||||
|
"%s-tailscale-%s-%s",
|
||||||
|
namespace,
|
||||||
|
strings.Replace(version, ".", "-", -1),
|
||||||
|
identifier,
|
||||||
|
)
|
||||||
tailscaleOptions := &dockertest.RunOptions{
|
tailscaleOptions := &dockertest.RunOptions{
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
Networks: []*dockertest.Network{&s.network},
|
Networks: []*dockertest.Network{&s.network},
|
||||||
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
|
Cmd: []string{
|
||||||
|
"tailscaled",
|
||||||
|
"--tun=userspace-networking",
|
||||||
|
"--socks5-server=localhost:1055",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy)
|
pts, err := s.pool.BuildAndRunWithBuildOptions(
|
||||||
|
tailscaleBuildOptions,
|
||||||
|
tailscaleOptions,
|
||||||
|
DockerRestartPolicy,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Could not start resource: %s", err)
|
log.Fatalf("Could not start resource: %s", err)
|
||||||
}
|
}
|
||||||
fmt.Printf("Created %s container\n", hostname)
|
fmt.Printf("Created %s container\n", hostname)
|
||||||
|
|
||||||
return hostname, pts
|
return hostname, pts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationTestSuite) SetupSuite() {
|
func (s *IntegrationTestSuite) SetupSuite() {
|
||||||
var err error
|
var err error
|
||||||
h = Headscale{
|
app = Headscale{
|
||||||
dbType: "sqlite3",
|
dbType: "sqlite3",
|
||||||
dbString: "integration_test_db.sqlite3",
|
dbString: "integration_test_db.sqlite3",
|
||||||
}
|
}
|
||||||
|
@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
||||||
for i := 0; i < scales.count; i++ {
|
for i := 0; i < scales.count; i++ {
|
||||||
version := tailscaleVersions[i%len(tailscaleVersions)]
|
version := tailscaleVersions[i%len(tailscaleVersions)]
|
||||||
|
|
||||||
hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version)
|
hostname, container := s.tailscaleContainer(
|
||||||
|
namespace,
|
||||||
|
fmt.Sprint(i),
|
||||||
|
version,
|
||||||
|
)
|
||||||
scales.tailscales[hostname] = *container
|
scales.tailscales[hostname] = *container
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
||||||
|
|
||||||
if err := s.pool.Retry(func() error {
|
if err := s.pool.Retry(func() error {
|
||||||
url := fmt.Sprintf("http://%s/health", hostEndpoint)
|
url := fmt.Sprintf("http://%s/health", hostEndpoint)
|
||||||
|
|
||||||
resp, err := http.Get(url)
|
resp, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("status code not OK")
|
return fmt.Errorf("status code not OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
||||||
|
@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
||||||
|
|
||||||
headscaleEndpoint := "http://headscale:8080"
|
headscaleEndpoint := "http://headscale:8080"
|
||||||
|
|
||||||
fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint)
|
fmt.Printf(
|
||||||
|
"Joining tailscale containers to headscale at %s\n",
|
||||||
|
headscaleEndpoint,
|
||||||
|
)
|
||||||
for hostname, tailscale := range scales.tailscales {
|
for hostname, tailscale := range scales.tailscales {
|
||||||
command := []string{
|
command := []string{
|
||||||
"tailscale",
|
"tailscale",
|
||||||
|
@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
||||||
func (s *IntegrationTestSuite) TearDownSuite() {
|
func (s *IntegrationTestSuite) TearDownSuite() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
|
func (s *IntegrationTestSuite) HandleStats(
|
||||||
|
suiteName string,
|
||||||
|
stats *suite.SuiteInformation,
|
||||||
|
) {
|
||||||
s.stats = stats
|
s.stats = stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
|
||||||
ip.String(),
|
ip.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
fmt.Printf(
|
||||||
|
"Pinging from %s (%s) to %s (%s)\n",
|
||||||
|
hostname,
|
||||||
|
ips[hostname],
|
||||||
|
peername,
|
||||||
|
ip,
|
||||||
|
)
|
||||||
result, err := ExecuteCommand(
|
result, err := ExecuteCommand(
|
||||||
&tailscale,
|
&tailscale,
|
||||||
command,
|
command,
|
||||||
|
@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
||||||
|
|
||||||
result, err := ExecuteCommand(
|
result, err := ExecuteCommand(
|
||||||
&s.headscale,
|
&s.headscale,
|
||||||
[]string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"},
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"nodes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
"--namespace",
|
||||||
|
"shared",
|
||||||
|
},
|
||||||
[]string{},
|
[]string{},
|
||||||
)
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
for _, machine := range machineList {
|
for _, machine := range machineList {
|
||||||
|
|
||||||
result, err := ExecuteCommand(
|
result, err := ExecuteCommand(
|
||||||
&s.headscale,
|
&s.headscale,
|
||||||
[]string{
|
[]string{
|
||||||
|
@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
||||||
ip.String(),
|
ip.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
|
fmt.Printf(
|
||||||
|
"Pinging from %s (%s) to %s (%s)\n",
|
||||||
|
hostname,
|
||||||
|
mainIps[hostname],
|
||||||
|
peername,
|
||||||
|
ip,
|
||||||
|
)
|
||||||
result, err := ExecuteCommand(
|
result, err := ExecuteCommand(
|
||||||
&tailscale,
|
&tailscale,
|
||||||
command,
|
command,
|
||||||
|
@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
||||||
for peername, ip := range ips {
|
for peername, ip := range ips {
|
||||||
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
|
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
|
||||||
if peername != hostname {
|
if peername != hostname {
|
||||||
|
|
||||||
// Under normal circumstances, we should be able to send a file
|
// Under normal circumstances, we should be able to send a file
|
||||||
// using `tailscale file cp` - but not in userspace networking mode
|
// using `tailscale file cp` - but not in userspace networking mode
|
||||||
// So curl!
|
// So curl!
|
||||||
|
@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
||||||
"PUT",
|
"PUT",
|
||||||
"--upload-file",
|
"--upload-file",
|
||||||
fmt.Sprintf("/tmp/file_from_%s", hostname),
|
fmt.Sprintf("/tmp/file_from_%s", hostname),
|
||||||
fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname),
|
fmt.Sprintf(
|
||||||
|
"%s/v0/put/file_from_%s",
|
||||||
|
peerAPI,
|
||||||
|
hostname,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
fmt.Printf(
|
||||||
|
"Sending file from %s (%s) to %s (%s)\n",
|
||||||
|
hostname,
|
||||||
|
ips[hostname],
|
||||||
|
peername,
|
||||||
|
ip,
|
||||||
|
)
|
||||||
_, err = ExecuteCommand(
|
_, err = ExecuteCommand(
|
||||||
&tailscale,
|
&tailscale,
|
||||||
command,
|
command,
|
||||||
|
@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
||||||
"ls",
|
"ls",
|
||||||
fmt.Sprintf("/tmp/file_from_%s", peername),
|
fmt.Sprintf("/tmp/file_from_%s", peername),
|
||||||
}
|
}
|
||||||
fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip)
|
fmt.Printf(
|
||||||
|
"Checking file in %s (%s) from %s (%s)\n",
|
||||||
|
hostname,
|
||||||
|
ips[hostname],
|
||||||
|
peername,
|
||||||
|
ip,
|
||||||
|
)
|
||||||
result, err := ExecuteCommand(
|
result, err := ExecuteCommand(
|
||||||
&tailscale,
|
&tailscale,
|
||||||
command,
|
command,
|
||||||
|
@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
fmt.Printf("Result for %s: %s\n", peername, result)
|
fmt.Printf("Result for %s: %s\n", peername, result)
|
||||||
assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername))
|
assert.Equal(
|
||||||
|
t,
|
||||||
|
result,
|
||||||
|
fmt.Sprintf("/tmp/file_from_%s\n", peername),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
|
||||||
|
|
||||||
ips[hostname] = ip
|
ips[hostname] = ip
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) {
|
func getAPIURLs(
|
||||||
|
tailscales map[string]dockertest.Resource,
|
||||||
|
) (map[netaddr.IP]string, error) {
|
||||||
fts := make(map[netaddr.IP]string)
|
fts := make(map[netaddr.IP]string)
|
||||||
for _, tailscale := range tailscales {
|
for _, tailscale := range tailscales {
|
||||||
command := []string{
|
command := []string{
|
||||||
|
@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fts, nil
|
return fts, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed.
|
||||||
You'll somehow need to get `headscale:latest` into your cluster image registry.
|
You'll somehow need to get `headscale:latest` into your cluster image registry.
|
||||||
|
|
||||||
An easy way to do this with k3s:
|
An easy way to do this with k3s:
|
||||||
|
|
||||||
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
|
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
|
||||||
- `docker build -t headscale:latest ..` from here
|
- `docker build -t headscale:latest ..` from here
|
||||||
|
|
||||||
|
|
402
machine.go
402
machine.go
|
@ -10,10 +10,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fatih/set"
|
"github.com/fatih/set"
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
|
@ -21,7 +20,13 @@ import (
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Machine is a Headscale client
|
const (
|
||||||
|
errMachineNotFound = Error("machine not found")
|
||||||
|
errMachineAlreadyRegistered = Error("machine already registered")
|
||||||
|
errMachineRouteIsNotAvailable = Error("route is not available on machine")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Machine is a Headscale client.
|
||||||
type Machine struct {
|
type Machine struct {
|
||||||
ID uint64 `gorm:"primary_key"`
|
ID uint64 `gorm:"primary_key"`
|
||||||
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
||||||
|
@ -56,53 +61,58 @@ type (
|
||||||
MachinesP []*Machine
|
MachinesP []*Machine
|
||||||
)
|
)
|
||||||
|
|
||||||
// For the time being this method is rather naive
|
// For the time being this method is rather naive.
|
||||||
func (m Machine) isAlreadyRegistered() bool {
|
func (machine Machine) isAlreadyRegistered() bool {
|
||||||
return m.Registered
|
return machine.Registered
|
||||||
}
|
}
|
||||||
|
|
||||||
// isExpired returns whether the machine registration has expired
|
// isExpired returns whether the machine registration has expired.
|
||||||
func (m Machine) isExpired() bool {
|
func (machine Machine) isExpired() bool {
|
||||||
return time.Now().UTC().After(*m.Expiry)
|
return time.Now().UTC().After(*machine.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
||||||
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
||||||
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
||||||
// expiry time.
|
// expiry time.
|
||||||
func (h *Headscale) updateMachineExpiry(m *Machine) {
|
func (h *Headscale) updateMachineExpiry(machine *Machine) {
|
||||||
if m.isExpired() {
|
if machine.isExpired() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
|
maxExpiry := now.Add(
|
||||||
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
|
h.cfg.MaxMachineRegistrationDuration,
|
||||||
|
) // calculate the maximum expiry
|
||||||
|
defaultExpiry := now.Add(
|
||||||
|
h.cfg.DefaultMachineRegistrationDuration,
|
||||||
|
) // calculate the default expiry
|
||||||
|
|
||||||
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
||||||
if maxExpiry.Before(*m.RequestedExpiry) {
|
if maxExpiry.Before(*machine.RequestedExpiry) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
||||||
m.Expiry = &maxExpiry
|
machine.Expiry = &maxExpiry
|
||||||
} else if m.RequestedExpiry.IsZero() {
|
} else if machine.RequestedExpiry.IsZero() {
|
||||||
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
||||||
m.Expiry = &defaultExpiry
|
machine.Expiry = &defaultExpiry
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
|
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
|
||||||
m.Expiry = m.RequestedExpiry
|
machine.Expiry = machine.RequestedExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding direct peers")
|
Msg("Finding direct peers")
|
||||||
|
|
||||||
machines := Machines{}
|
machines := Machines{}
|
||||||
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found direct machines: %s", machines.String())
|
Msgf("Found direct machines: %s", machines.String())
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
|
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
|
||||||
func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding shared peers")
|
Msg("Finding shared peers")
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
||||||
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found shared peers: %s", peers.String())
|
Msgf("Found shared peers: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSharedTo fetches the machines of the namespaces this machine is shared in
|
// getSharedTo fetches the machines of the namespaces this machine is shared in.
|
||||||
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding peers in namespaces this machine is shared with")
|
Msg("Finding peers in namespaces this machine is shared with")
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
||||||
m.ID).Find(&sharedMachines).Error; err != nil {
|
machine.ID).Find(&sharedMachines).Error; err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers := make(Machines, 0)
|
peers := make(Machines, 0)
|
||||||
for _, sharedMachine := range sharedMachines {
|
for _, sharedMachine := range sharedMachines {
|
||||||
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
|
namespaceMachines, err := h.ListMachinesInNamespace(
|
||||||
|
sharedMachine.Namespace.Name,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found peers we are shared with: %s", peers.String())
|
Msgf("Found peers we are shared with: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||||
direct, err := h.getDirectPeers(m)
|
direct, err := h.getDirectPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot fetch peers")
|
Msg("Cannot fetch peers")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
shared, err := h.getShared(m)
|
shared, err := h.getShared(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot fetch peers")
|
Msg("Cannot fetch peers")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedTo, err := h.getSharedTo(m)
|
sharedTo, err := h.getSharedTo(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot fetch peers")
|
Msg("Cannot fetch peers")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found total peers: %s", peers.String())
|
Msgf("Found total peers: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
|
@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
|
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachine finds a Machine by name and namespace and returns the Machine struct
|
// GetMachine finds a Machine by name and namespace and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
|
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
|
||||||
machines, err := h.ListMachinesInNamespace(namespace)
|
machines, err := h.ListMachinesInNamespace(namespace)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("machine not found")
|
|
||||||
|
return nil, errMachineNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByID finds a Machine by ID and returns the Machine struct
|
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct
|
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||||
// and updates it with the latest data from the database.
|
// and updates it with the latest data from the database.
|
||||||
func (h *Headscale) UpdateMachine(m *Machine) error {
|
func (h *Headscale) UpdateMachine(machine *Machine) error {
|
||||||
if result := h.db.Find(m).First(&m); result.Error != nil {
|
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMachine softs deletes a Machine from the database
|
// DeleteMachine softs deletes a Machine from the database.
|
||||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
||||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||||
if err != nil && err != errorMachineNotShared {
|
if err != nil && errors.Is(err, errMachineNotShared) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Registered = false
|
machine.Registered = false
|
||||||
namespaceID := m.NamespaceID
|
namespaceID := machine.NamespaceID
|
||||||
h.db.Save(&m) // we mark it as unregistered, just in case
|
h.db.Save(&machine) // we mark it as unregistered, just in case
|
||||||
if err := h.db.Delete(&m).Error; err != nil {
|
if err := h.db.Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.RequestMapUpdates(namespaceID)
|
return h.RequestMapUpdates(namespaceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HardDeleteMachine hard deletes a Machine from the database
|
// HardDeleteMachine hard deletes a Machine from the database.
|
||||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
||||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||||
if err != nil && err != errorMachineNotShared {
|
if err != nil && errors.Is(err, errMachineNotShared) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
namespaceID := m.NamespaceID
|
namespaceID := machine.NamespaceID
|
||||||
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
|
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.RequestMapUpdates(namespaceID)
|
return h.RequestMapUpdates(namespaceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHostInfo returns a Hostinfo struct for the machine
|
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||||
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &hostinfo, nil
|
return &hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) isOutdated(m *Machine) bool {
|
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||||
err := h.UpdateMachine(m)
|
if err := h.UpdateMachine(machine); err != nil {
|
||||||
if err != nil {
|
|
||||||
// It does not seem meaningful to propagate this error as the end result
|
// It does not seem meaningful to propagate this error as the end result
|
||||||
// will have to be that the machine has to be considered outdated.
|
// will have to be that the machine has to be considered outdated.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachines, _ := h.getShared(m)
|
sharedMachines, _ := h.getShared(machine)
|
||||||
|
|
||||||
namespaceSet := set.New(set.ThreadSafe)
|
namespaceSet := set.New(set.ThreadSafe)
|
||||||
namespaceSet.Add(m.Namespace.Name)
|
namespaceSet.Add(machine.Namespace.Name)
|
||||||
|
|
||||||
// Check if any of our shared namespaces has updates that we have
|
// Check if any of our shared namespaces has updates that we have
|
||||||
// not propagated.
|
// not propagated.
|
||||||
|
@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool {
|
||||||
|
|
||||||
namespaces := make([]string, namespaceSet.Size())
|
namespaces := make([]string, namespaceSet.Size())
|
||||||
for index, namespace := range namespaceSet.List() {
|
for index, namespace := range namespaceSet.List() {
|
||||||
namespaces[index] = namespace.(string)
|
if name, ok := namespace.(string); ok {
|
||||||
|
namespaces[index] = name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lastChange := h.getLastStateChange(namespaces...)
|
lastChange := h.getLastStateChange(namespaces...)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||||
Time("last_state_change", lastChange).
|
Time("last_state_change", lastChange).
|
||||||
Msgf("Checking if %s is missing updates", m.Name)
|
Msgf("Checking if %s is missing updates", machine.Name)
|
||||||
return m.LastSuccessfulUpdate.Before(lastChange)
|
|
||||||
|
return machine.LastSuccessfulUpdate.Before(lastChange)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Machine) String() string {
|
func (machine Machine) String() string {
|
||||||
return m.Name
|
return machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms Machines) String() string {
|
func (machines Machines) String() string {
|
||||||
temp := make([]string, len(ms))
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
temp[index] = machine.Name
|
temp[index] = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,24 +387,24 @@ func (ms Machines) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Remove when we have generics...
|
// TODO(kradalby): Remove when we have generics...
|
||||||
func (ms MachinesP) String() string {
|
func (machines MachinesP) String() string {
|
||||||
temp := make([]string, len(ms))
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
temp[index] = machine.Name
|
temp[index] = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms Machines) toNodes(
|
func (machines Machines) toNodes(
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
includeRoutes bool,
|
includeRoutes bool,
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
nodes := make([]*tailcfg.Node, len(ms))
|
nodes := make([]*tailcfg.Node, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -391,20 +417,25 @@ func (ms Machines) toNodes(
|
||||||
}
|
}
|
||||||
|
|
||||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||||
// as per the expected behaviour in the official SaaS
|
// as per the expected behaviour in the official SaaS.
|
||||||
func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
|
func (machine Machine) toNode(
|
||||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
baseDomain string,
|
||||||
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
|
includeRoutes bool,
|
||||||
|
) (*tailcfg.Node, error) {
|
||||||
|
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mKey, err := wgkey.ParseHex(m.MachineKey)
|
|
||||||
|
machineKey, err := wgkey.ParseHex(machine.MachineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var discoKey tailcfg.DiscoKey
|
var discoKey tailcfg.DiscoKey
|
||||||
if m.DiscoKey != "" {
|
if machine.DiscoKey != "" {
|
||||||
dKey, err := wgkey.ParseHex(m.DiscoKey)
|
dKey, err := wgkey.ParseHex(machine.DiscoKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs := []netaddr.IPPrefix{}
|
addrs := []netaddr.IPPrefix{}
|
||||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
|
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("ip", m.IPAddress).
|
Str("ip", machine.IPAddress).
|
||||||
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
|
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
addrs = append(addrs, ip) // missing the ipv6 ?
|
addrs = append(addrs, ip) // missing the ipv6 ?
|
||||||
|
|
||||||
allowedIPs := []netaddr.IPPrefix{}
|
allowedIPs := []netaddr.IPPrefix{}
|
||||||
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
|
allowedIPs = append(
|
||||||
|
allowedIPs,
|
||||||
|
ip,
|
||||||
|
) // we append the node own IP, as it is required by the clients
|
||||||
|
|
||||||
if includeRoutes {
|
if includeRoutes {
|
||||||
routesStr := []string{}
|
routesStr := []string{}
|
||||||
if len(m.EnabledRoutes) != 0 {
|
if len(machine.EnabledRoutes) != 0 {
|
||||||
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
allwIps, err := machine.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoints := []string{}
|
endpoints := []string{}
|
||||||
if len(m.Endpoints) != 0 {
|
if len(machine.Endpoints) != 0 {
|
||||||
be, err := m.Endpoints.MarshalJSON()
|
be, err := machine.Endpoints.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyExpiry time.Time
|
var keyExpiry time.Time
|
||||||
if m.Expiry != nil {
|
if machine.Expiry != nil {
|
||||||
keyExpiry = *m.Expiry
|
keyExpiry = *machine.Expiry
|
||||||
} else {
|
} else {
|
||||||
keyExpiry = time.Time{}
|
keyExpiry = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var hostname string
|
var hostname string
|
||||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||||
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
|
hostname = fmt.Sprintf(
|
||||||
|
"%s.%s.%s",
|
||||||
|
machine.Name,
|
||||||
|
machine.Namespace.Name,
|
||||||
|
baseDomain,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
hostname = m.Name
|
hostname = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
n := tailcfg.Node{
|
node := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(
|
StableID: tailcfg.StableNodeID(
|
||||||
strconv.FormatUint(m.ID, 10),
|
strconv.FormatUint(machine.ID, Base10),
|
||||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
User: tailcfg.UserID(m.NamespaceID),
|
User: tailcfg.UserID(machine.NamespaceID),
|
||||||
Key: tailcfg.NodeKey(nKey),
|
Key: tailcfg.NodeKey(nodeKey),
|
||||||
KeyExpiry: keyExpiry,
|
KeyExpiry: keyExpiry,
|
||||||
Machine: tailcfg.MachineKey(mKey),
|
Machine: tailcfg.MachineKey(machineKey),
|
||||||
DiscoKey: discoKey,
|
DiscoKey: discoKey,
|
||||||
Addresses: addrs,
|
Addresses: addrs,
|
||||||
AllowedIPs: allowedIPs,
|
AllowedIPs: allowedIPs,
|
||||||
|
@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||||
DERP: derp,
|
DERP: derp,
|
||||||
|
|
||||||
Hostinfo: hostinfo,
|
Hostinfo: hostinfo,
|
||||||
Created: m.CreatedAt,
|
Created: machine.CreatedAt,
|
||||||
LastSeen: m.LastSeen,
|
LastSeen: machine.LastSeen,
|
||||||
|
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
MachineAuthorized: m.Registered,
|
MachineAuthorized: machine.Registered,
|
||||||
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
||||||
}
|
}
|
||||||
return &n, nil
|
|
||||||
|
return &node, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) toProto() *v1.Machine {
|
func (machine *Machine) toProto() *v1.Machine {
|
||||||
machine := &v1.Machine{
|
machineProto := &v1.Machine{
|
||||||
Id: m.ID,
|
Id: machine.ID,
|
||||||
MachineKey: m.MachineKey,
|
MachineKey: machine.MachineKey,
|
||||||
|
|
||||||
NodeKey: m.NodeKey,
|
NodeKey: machine.NodeKey,
|
||||||
DiscoKey: m.DiscoKey,
|
DiscoKey: machine.DiscoKey,
|
||||||
IpAddress: m.IPAddress,
|
IpAddress: machine.IPAddress,
|
||||||
Name: m.Name,
|
Name: machine.Name,
|
||||||
Namespace: m.Namespace.toProto(),
|
Namespace: machine.Namespace.toProto(),
|
||||||
|
|
||||||
Registered: m.Registered,
|
Registered: machine.Registered,
|
||||||
|
|
||||||
// TODO(kradalby): Implement register method enum converter
|
// TODO(kradalby): Implement register method enum converter
|
||||||
// RegisterMethod: ,
|
// RegisterMethod: ,
|
||||||
|
|
||||||
CreatedAt: timestamppb.New(m.CreatedAt),
|
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.AuthKey != nil {
|
if machine.AuthKey != nil {
|
||||||
machine.PreAuthKey = m.AuthKey.toProto()
|
machineProto.PreAuthKey = machine.AuthKey.toProto()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.LastSeen != nil {
|
if machine.LastSeen != nil {
|
||||||
machine.LastSeen = timestamppb.New(*m.LastSeen)
|
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.LastSuccessfulUpdate != nil {
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
|
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||||
|
*machine.LastSuccessfulUpdate,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Expiry != nil {
|
if machine.Expiry != nil {
|
||||||
machine.Expiry = timestamppb.New(*m.Expiry)
|
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
return machine
|
return machineProto
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
|
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||||
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
|
func (h *Headscale) RegisterMachine(
|
||||||
ns, err := h.GetNamespace(namespace)
|
key string,
|
||||||
|
namespaceName string,
|
||||||
|
) (*Machine, error) {
|
||||||
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mKey, err := wgkey.ParseHex(key)
|
machineKey, err := wgkey.ParseHex(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m := Machine{}
|
machine := Machine{}
|
||||||
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
|
||||||
return nil, errors.New("Machine not found")
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return nil, errMachineNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Attempting to register machine")
|
Msg("Attempting to register machine")
|
||||||
|
|
||||||
if m.isAlreadyRegistered() {
|
if machine.isAlreadyRegistered() {
|
||||||
err := errors.New("Machine already registered")
|
err := errMachineAlreadyRegistered
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Attempting to register machine")
|
Msg("Attempting to register machine")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Could not find IP for the new machine")
|
Msg("Could not find IP for the new machine")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Found IP for host")
|
Msg("Found IP for host")
|
||||||
|
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = ns.ID
|
machine.NamespaceID = namespace.ID
|
||||||
m.Registered = true
|
machine.Registered = true
|
||||||
m.RegisterMethod = "cli"
|
machine.RegisterMethod = "cli"
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Machine registered with the database")
|
Msg("Machine registered with the database")
|
||||||
|
|
||||||
return &m, nil
|
return &machine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
hostInfo, err := m.GetHostInfo()
|
hostInfo, err := machine.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hostInfo.RoutableIPs, nil
|
return hostInfo.RoutableIPs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
data, err := m.EnabledRoutes.MarshalJSON()
|
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledRoutes, err := m.GetEnabledRoutes()
|
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
||||||
// previous list of routes.
|
// previous list of routes.
|
||||||
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
||||||
for index, routeStr := range routeStrs {
|
for index, routeStr := range routeStrs {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
|
@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
newRoutes[index] = route
|
newRoutes[index] = route
|
||||||
}
|
}
|
||||||
|
|
||||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
if !containsIpPrefix(availableRoutes, newRoute) {
|
if !containsIPPrefix(availableRoutes, newRoute) {
|
||||||
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
|
return fmt.Errorf(
|
||||||
|
"route (%s) is not available on node %s: %w",
|
||||||
|
machine.Name,
|
||||||
|
newRoute, errMachineRouteIsNotAvailable,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.EnabledRoutes = datatypes.JSON(routes)
|
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
err = h.RequestMapUpdates(m.NamespaceID)
|
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
|
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledRoutes, err := m.GetEnabledRoutes()
|
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
113
machine_test.go
113
machine_test.go
|
@ -8,152 +8,159 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetMachine(c *check.C) {
|
func (s *Suite) TestGetMachine(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m := &Machine{
|
machine := &Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m)
|
app.db.Save(machine)
|
||||||
|
|
||||||
m1, err := h.GetMachine("test", "testmachine")
|
machineFromDB, err := app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = m1.GetHostInfo()
|
_, err = machineFromDB.GetHostInfo()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachineByID(0)
|
_, err = app.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
m1, err := h.GetMachineByID(0)
|
machineByID, err := app.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = m1.GetHostInfo()
|
_, err = machineByID.GetHostInfo()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(1),
|
AuthKeyID: uint(1),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
err = h.DeleteMachine(&m)
|
|
||||||
|
err = app.DeleteMachine(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
v, err := h.getValue("namespaces_pending_updates")
|
|
||||||
|
namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
names := []string{}
|
names := []string{}
|
||||||
err = json.Unmarshal([]byte(v), &names)
|
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(names, check.DeepEquals, []string{n.Name})
|
c.Assert(names, check.DeepEquals, []string{namespace.Name})
|
||||||
h.checkForNamespacesPendingUpdates()
|
|
||||||
v, _ = h.getValue("namespaces_pending_updates")
|
app.checkForNamespacesPendingUpdates()
|
||||||
c.Assert(v, check.Equals, "")
|
|
||||||
_, err = h.GetMachine(n.Name, "testmachine")
|
namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates")
|
||||||
|
c.Assert(namespacesPendingUpdates, check.Equals, "")
|
||||||
|
_, err = app.GetMachine(namespace.Name, "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine3",
|
Name: "testmachine3",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(1),
|
AuthKeyID: uint(1),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
err = h.HardDeleteMachine(&m)
|
|
||||||
|
err = app.HardDeleteMachine(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
_, err = h.GetMachine(n.Name, "testmachine3")
|
|
||||||
|
_, err = app.GetMachine(namespace.Name, "testmachine3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetDirectPeers(c *check.C) {
|
func (s *Suite) TestGetDirectPeers(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachineByID(0)
|
_, err = app.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
for i := 0; i <= 10; i++ {
|
for index := 0; index <= 10; index++ {
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: uint64(i),
|
ID: uint64(index),
|
||||||
MachineKey: "foo" + strconv.Itoa(i),
|
MachineKey: "foo" + strconv.Itoa(index),
|
||||||
NodeKey: "bar" + strconv.Itoa(i),
|
NodeKey: "bar" + strconv.Itoa(index),
|
||||||
DiscoKey: "faa" + strconv.Itoa(i),
|
DiscoKey: "faa" + strconv.Itoa(index),
|
||||||
Name: "testmachine" + strconv.Itoa(i),
|
Name: "testmachine" + strconv.Itoa(index),
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
m1, err := h.GetMachineByID(0)
|
machine0ByID, err := app.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = m1.GetHostInfo()
|
_, err = machine0ByID.GetHostInfo()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peers, err := h.getDirectPeers(m1)
|
peersOfMachine0, err := app.getDirectPeers(machine0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(peers), check.Equals, 9)
|
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||||
c.Assert(peers[0].Name, check.Equals, "testmachine2")
|
c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2")
|
||||||
c.Assert(peers[5].Name, check.Equals, "testmachine7")
|
c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7")
|
||||||
c.Assert(peers[8].Name, check.Equals, "testmachine10")
|
c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10")
|
||||||
}
|
}
|
||||||
|
|
149
namespaces.go
149
namespaces.go
|
@ -15,9 +15,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errorNamespaceExists = Error("Namespace already exists")
|
errNamespaceExists = Error("Namespace already exists")
|
||||||
errorNamespaceNotFound = Error("Namespace not found")
|
errNamespaceNotFound = Error("Namespace not found")
|
||||||
errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
|
errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Namespace is the way Headscale implements the concept of users in Tailscale
|
// Namespace is the way Headscale implements the concept of users in Tailscale
|
||||||
|
@ -30,51 +30,53 @@ type Namespace struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
||||||
// or another namespace already exists
|
// or another namespace already exists.
|
||||||
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||||
n := Namespace{}
|
namespace := Namespace{}
|
||||||
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
|
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
|
||||||
return nil, errorNamespaceExists
|
return nil, errNamespaceExists
|
||||||
}
|
}
|
||||||
n.Name = name
|
namespace.Name = name
|
||||||
if err := h.db.Create(&n).Error; err != nil {
|
if err := h.db.Create(&namespace).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "CreateNamespace").
|
Str("func", "CreateNamespace").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not create row")
|
Msg("Could not create row")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &n, nil
|
|
||||||
|
return &namespace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
||||||
// not exist or if there are machines associated with it.
|
// not exist or if there are machines associated with it.
|
||||||
func (h *Headscale) DestroyNamespace(name string) error {
|
func (h *Headscale) DestroyNamespace(name string) error {
|
||||||
n, err := h.GetNamespace(name)
|
namespace, err := h.GetNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorNamespaceNotFound
|
return errNamespaceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := h.ListMachinesInNamespace(name)
|
machines, err := h.ListMachinesInNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(m) > 0 {
|
if len(machines) > 0 {
|
||||||
return errorNamespaceNotEmptyOfNodes
|
return errNamespaceNotEmptyOfNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := h.ListPreAuthKeys(name)
|
keys, err := h.ListPreAuthKeys(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, p := range keys {
|
for _, key := range keys {
|
||||||
err = h.DestroyPreAuthKey(&p)
|
err = h.DestroyPreAuthKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
|
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
||||||
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
||||||
// not exist or if another Namespace exists with the new name.
|
// not exist or if another Namespace exists with the new name.
|
||||||
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||||
n, err := h.GetNamespace(oldName)
|
oldNamespace, err := h.GetNamespace(oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = h.GetNamespace(newName)
|
_, err = h.GetNamespace(newName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return errorNamespaceExists
|
return errNamespaceExists
|
||||||
}
|
}
|
||||||
if !errors.Is(err, errorNamespaceNotFound) {
|
if !errors.Is(err, errNamespaceNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
n.Name = newName
|
oldNamespace.Name = newName
|
||||||
|
|
||||||
if result := h.db.Save(&n); result.Error != nil {
|
if result := h.db.Save(&oldNamespace); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.RequestMapUpdates(n.ID)
|
err = h.RequestMapUpdates(oldNamespace.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNamespace fetches a namespace by name
|
// GetNamespace fetches a namespace by name.
|
||||||
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
||||||
n := Namespace{}
|
namespace := Namespace{}
|
||||||
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
|
||||||
return nil, errorNamespaceNotFound
|
result.Error,
|
||||||
}
|
gorm.ErrRecordNotFound,
|
||||||
return &n, nil
|
) {
|
||||||
|
return nil, errNamespaceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNamespaces gets all the existing namespaces
|
return &namespace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNamespaces gets all the existing namespaces.
|
||||||
func (h *Headscale) ListNamespaces() ([]Namespace, error) {
|
func (h *Headscale) ListNamespaces() ([]Namespace, error) {
|
||||||
namespaces := []Namespace{}
|
namespaces := []Namespace{}
|
||||||
if err := h.db.Find(&namespaces).Error; err != nil {
|
if err := h.db.Find(&namespaces).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return namespaces, nil
|
return namespaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListMachinesInNamespace gets all the nodes in a given namespace
|
// ListMachinesInNamespace gets all the nodes in a given namespace.
|
||||||
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
||||||
n, err := h.GetNamespace(name)
|
namespace, err := h.GetNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
|
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
|
||||||
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
|
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
|
||||||
namespace, err := h.GetNamespace(name)
|
namespace, err := h.GetNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
for _, sharedMachine := range sharedMachines {
|
for _, sharedMachine := range sharedMachines {
|
||||||
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
|
machine, err := h.GetMachineByID(
|
||||||
|
sharedMachine.MachineID,
|
||||||
|
) // otherwise not everything comes filled
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
machines = append(machines, *machine)
|
machines = append(machines, *machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMachineNamespace assigns a Machine to a namespace
|
// SetMachineNamespace assigns a Machine to a namespace.
|
||||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.NamespaceID = n.ID
|
machine.NamespaceID = namespace.ID
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace
|
// TODO(kradalby): Remove the need for this.
|
||||||
|
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
|
||||||
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
namespace := Namespace{}
|
namespace := Namespace{}
|
||||||
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
|
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := h.getValue("namespaces_pending_updates")
|
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil || v == "" {
|
if err != nil || namespacesPendingUpdates == "" {
|
||||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
err = h.setValue(
|
||||||
|
"namespaces_pending_updates",
|
||||||
|
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
names := []string{}
|
names := []string{}
|
||||||
err = json.Unmarshal([]byte(v), &names)
|
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
err = h.setValue(
|
||||||
|
"namespaces_pending_updates",
|
||||||
|
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
Str("func", "RequestMapUpdates").
|
Str("func", "RequestMapUpdates").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not marshal namespaces_pending_updates")
|
Msg("Could not marshal namespaces_pending_updates")
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.setValue("namespaces_pending_updates", string(data))
|
return h.setValue("namespaces_pending_updates", string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||||
v, err := h.getValue("namespaces_pending_updates")
|
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v == "" {
|
if namespacesPendingUpdates == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
namespaces := []string{}
|
namespaces := []string{}
|
||||||
err = json.Unmarshal([]byte(v), &namespaces)
|
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||||
Msg("Sending updates to nodes in namespacespace")
|
Msg("Sending updates to nodes in namespacespace")
|
||||||
h.setLastStateChangeToNow(namespace)
|
h.setLastStateChangeToNow(namespace)
|
||||||
}
|
}
|
||||||
newV, err := h.getValue("namespaces_pending_updates")
|
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v == newV { // only clear when no changes, so we notified everybody
|
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
|
||||||
err = h.setValue("namespaces_pending_updates", "")
|
err = h.setValue("namespaces_pending_updates", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "checkForNamespacesPendingUpdates").
|
Str("func", "checkForNamespacesPendingUpdates").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not save to KV")
|
Msg("Could not save to KV")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Namespace) toUser() *tailcfg.User {
|
func (n *Namespace) toUser() *tailcfg.User {
|
||||||
u := tailcfg.User{
|
user := tailcfg.User{
|
||||||
ID: tailcfg.UserID(n.ID),
|
ID: tailcfg.UserID(n.ID),
|
||||||
LoginName: n.Name,
|
LoginName: n.Name,
|
||||||
DisplayName: n.Name,
|
DisplayName: n.Name,
|
||||||
|
@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User {
|
||||||
Logins: []tailcfg.LoginID{},
|
Logins: []tailcfg.LoginID{},
|
||||||
Created: time.Time{},
|
Created: time.Time{},
|
||||||
}
|
}
|
||||||
return &u
|
|
||||||
|
return &user
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Namespace) toLogin() *tailcfg.Login {
|
func (n *Namespace) toLogin() *tailcfg.Login {
|
||||||
l := tailcfg.Login{
|
login := tailcfg.Login{
|
||||||
ID: tailcfg.LoginID(n.ID),
|
ID: tailcfg.LoginID(n.ID),
|
||||||
LoginName: n.Name,
|
LoginName: n.Name,
|
||||||
DisplayName: n.Name,
|
DisplayName: n.Name,
|
||||||
ProfilePicURL: "",
|
ProfilePicURL: "",
|
||||||
Domain: "headscale.net",
|
Domain: "headscale.net",
|
||||||
}
|
}
|
||||||
return &l
|
|
||||||
|
return &login
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
|
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
|
||||||
namespaceMap := make(map[string]Namespace)
|
namespaceMap := make(map[string]Namespace)
|
||||||
namespaceMap[m.Namespace.Name] = m.Namespace
|
namespaceMap[machine.Namespace.Name] = machine.Namespace
|
||||||
for _, p := range peers {
|
for _, peer := range peers {
|
||||||
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
|
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := []tailcfg.UserProfile{}
|
profiles := []tailcfg.UserProfile{}
|
||||||
|
@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
|
||||||
DisplayName: namespace.Name,
|
DisplayName: namespace.Name,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return profiles
|
return profiles
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Namespace) toProto() *v1.Namespace {
|
func (n *Namespace) toProto() *v1.Namespace {
|
||||||
return &v1.Namespace{
|
return &v1.Namespace{
|
||||||
Id: strconv.FormatUint(uint64(n.ID), 10),
|
Id: strconv.FormatUint(uint64(n.ID), Base10),
|
||||||
Name: n.Name,
|
Name: n.Name,
|
||||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,207 +7,232 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
|
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(n.Name, check.Equals, "test")
|
c.Assert(namespace.Name, check.Equals, "test")
|
||||||
|
|
||||||
ns, err := h.ListNamespaces()
|
namespaces, err := app.ListNamespaces()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(ns), check.Equals, 1)
|
c.Assert(len(namespaces), check.Equals, 1)
|
||||||
|
|
||||||
err = h.DestroyNamespace("test")
|
err = app.DestroyNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetNamespace("test")
|
_, err = app.GetNamespace("test")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
|
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
|
||||||
err := h.DestroyNamespace("test")
|
err := app.DestroyNamespace("test")
|
||||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||||
|
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = h.DestroyNamespace("test")
|
err = app.DestroyNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
|
result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
|
||||||
// destroying a namespace also deletes all associated preauthkeys
|
// destroying a namespace also deletes all associated preauthkeys
|
||||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
n, err = h.CreateNamespace("test")
|
namespace, err = app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err = h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
err = h.DestroyNamespace("test")
|
err = app.DestroyNamespace("test")
|
||||||
c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes)
|
c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestRenameNamespace(c *check.C) {
|
func (s *Suite) TestRenameNamespace(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespaceTest, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(n.Name, check.Equals, "test")
|
c.Assert(namespaceTest.Name, check.Equals, "test")
|
||||||
|
|
||||||
ns, err := h.ListNamespaces()
|
namespaces, err := app.ListNamespaces()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(ns), check.Equals, 1)
|
c.Assert(len(namespaces), check.Equals, 1)
|
||||||
|
|
||||||
err = h.RenameNamespace("test", "test_renamed")
|
err = app.RenameNamespace("test", "test_renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetNamespace("test")
|
_, err = app.GetNamespace("test")
|
||||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||||
|
|
||||||
_, err = h.GetNamespace("test_renamed")
|
_, err = app.GetNamespace("test_renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = h.RenameNamespace("test_does_not_exit", "test")
|
err = app.RenameNamespace("test_does_not_exit", "test")
|
||||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||||
|
|
||||||
n2, err := h.CreateNamespace("test2")
|
namespaceTest2, err := app.CreateNamespace("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(n2.Name, check.Equals, "test2")
|
c.Assert(namespaceTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
err = h.RenameNamespace("test2", "test_renamed")
|
err = app.RenameNamespace("test2", "test_renamed")
|
||||||
c.Assert(err, check.Equals, errorNamespaceExists)
|
c.Assert(err, check.Equals, errNamespaceExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
n1, err := h.CreateNamespace("shared1")
|
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n2, err := h.CreateNamespace("shared2")
|
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n3, err := h.CreateNamespace("shared3")
|
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
preAuthKeyShared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
preAuthKeyShared2, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared2.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
preAuthKeyShared3, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared3.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
preAuthKey2Shared1, err := app.CreatePreAuthKey(
|
||||||
|
namespaceShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
Name: "test_get_shared_nodes_1",
|
Name: "test_get_shared_nodes_1",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.1",
|
IPAddress: "100.64.0.1",
|
||||||
AuthKeyID: uint(pak1n1.ID),
|
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m1)
|
app.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_2",
|
Name: "test_get_shared_nodes_2",
|
||||||
NamespaceID: n2.ID,
|
NamespaceID: namespaceShared2.ID,
|
||||||
Namespace: *n2,
|
Namespace: *namespaceShared2,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.2",
|
IPAddress: "100.64.0.2",
|
||||||
AuthKeyID: uint(pak2n2.ID),
|
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m2)
|
app.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_3",
|
Name: "test_get_shared_nodes_3",
|
||||||
NamespaceID: n3.ID,
|
NamespaceID: namespaceShared3.ID,
|
||||||
Namespace: *n3,
|
Namespace: *namespaceShared3,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.3",
|
IPAddress: "100.64.0.3",
|
||||||
AuthKeyID: uint(pak3n3.ID),
|
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m3)
|
app.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m4 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
Name: "test_get_shared_nodes_4",
|
Name: "test_get_shared_nodes_4",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespaceShared1.ID,
|
||||||
Namespace: *n1,
|
Namespace: *namespaceShared1,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.4",
|
IPAddress: "100.64.0.4",
|
||||||
AuthKeyID: uint(pak4n1.ID),
|
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m4)
|
app.db.Save(machine2InShared1)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
m1peers, err := h.getPeers(m1)
|
peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userProfiles := getMapResponseUserProfiles(*m1, m1peers)
|
userProfiles := getMapResponseUserProfiles(
|
||||||
|
*machineInShared1,
|
||||||
|
peersOfMachine1InShared1,
|
||||||
|
)
|
||||||
|
|
||||||
log.Trace().Msgf("userProfiles %#v", userProfiles)
|
log.Trace().Msgf("userProfiles %#v", userProfiles)
|
||||||
c.Assert(len(userProfiles), check.Equals, 2)
|
c.Assert(len(userProfiles), check.Equals, 2)
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, up := range userProfiles {
|
for _, userProfiles := range userProfiles {
|
||||||
if up.DisplayName == n1.Name {
|
if userProfiles.DisplayName == namespaceShared1.Name {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Assert(found, check.Equals, true)
|
c.Assert(found, check.Equals, true)
|
||||||
|
|
||||||
found = false
|
found = false
|
||||||
for _, up := range userProfiles {
|
for _, userProfile := range userProfiles {
|
||||||
if up.DisplayName == n2.Name {
|
if userProfile.DisplayName == namespaceShared2.Name {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
138
oidc.go
138
oidc.go
|
@ -17,6 +17,12 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
oidcStateCacheExpiration = time.Minute * 5
|
||||||
|
oidcStateCacheCleanupInterval = time.Minute * 10
|
||||||
|
randomByteSize = 16
|
||||||
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Groups []string `json:"groups,omitempty"`
|
Groups []string `json:"groups,omitempty"`
|
||||||
|
@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error {
|
||||||
ClientID: h.cfg.OIDC.ClientID,
|
ClientID: h.cfg.OIDC.ClientID,
|
||||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||||
Endpoint: h.oidcProvider.Endpoint(),
|
Endpoint: h.oidcProvider.Endpoint(),
|
||||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
RedirectURL: fmt.Sprintf(
|
||||||
|
"%s/oidc/callback",
|
||||||
|
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||||
|
),
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// init the state cache if it hasn't been already
|
// init the state cache if it hasn't been already
|
||||||
if h.oidcStateCache == nil {
|
if h.oidcStateCache == nil {
|
||||||
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
|
h.oidcStateCache = cache.New(
|
||||||
|
oidcStateCacheExpiration,
|
||||||
|
oidcStateCacheCleanupInterval,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error {
|
||||||
|
|
||||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||||
// Listens in /oidc/register/:mKey
|
// Listens in /oidc/register/:mKey.
|
||||||
func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
mKeyStr := c.Param("mkey")
|
mKeyStr := ctx.Param("mkey")
|
||||||
if mKeyStr == "" {
|
if mKeyStr == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b := make([]byte, 16)
|
randomBlob := make([]byte, randomByteSize)
|
||||||
_, err := rand.Read(b)
|
if _, err := rand.Read(randomBlob); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error().Msg("could not read 16 bytes from rand")
|
log.Error().Msg("could not read 16 bytes from rand")
|
||||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stateStr := hex.EncodeToString(b)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the machine key into the state cache, so it can be retrieved later
|
// place the machine key into the state cache, so it can be retrieved later
|
||||||
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
|
||||||
|
|
||||||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, authUrl)
|
ctx.Redirect(http.StatusFound, authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCCallback handles the callback from the OIDC endpoint
|
// OIDCCallback handles the callback from the OIDC endpoint
|
||||||
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
|
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
|
||||||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
||||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||||
// Listens in /oidc/callback
|
// Listens in /oidc/callback.
|
||||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
code := c.Query("code")
|
code := ctx.Query("code")
|
||||||
state := c.Query("state")
|
state := ctx.Query("state")
|
||||||
|
|
||||||
if code == "" || state == "" {
|
if code == "" || state == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||||
if !rawIDTokenOK {
|
if !rawIDTokenOK {
|
||||||
c.String(http.StatusBadRequest, "Could not extract ID Token")
|
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +130,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,7 +145,11 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
// Extract custom claims
|
// Extract custom claims
|
||||||
var claims IDTokenClaims
|
var claims IDTokenClaims
|
||||||
if err = idToken.Claims(&claims); err != nil {
|
if err = idToken.Claims(&claims); err != nil {
|
||||||
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
|
ctx.String(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||||
|
|
||||||
if !mKeyFound {
|
if !mKeyFound {
|
||||||
log.Error().Msg("requested machine state key expired before authorisation completed")
|
log.Error().
|
||||||
c.String(http.StatusBadRequest, "state has expired")
|
Msg("requested machine state key expired before authorisation completed")
|
||||||
|
ctx.String(http.StatusBadRequest, "state has expired")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mKeyStr, mKeyOK := mKeyIf.(string)
|
mKeyStr, mKeyOK := mKeyIf.(string)
|
||||||
|
|
||||||
if !mKeyOK {
|
if !mKeyOK {
|
||||||
log.Error().Msg("could not get machine key from cache")
|
log.Error().Msg("could not get machine key from cache")
|
||||||
c.String(http.StatusInternalServerError, "could not get machine key from cache")
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not get machine key from cache",
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve machine information
|
// retrieve machine information
|
||||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("machine key not found in database")
|
log.Error().Msg("machine key not found in database")
|
||||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not get machine info from database",
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||||
// register the machine if it's new
|
// register the machine if it's new
|
||||||
if !m.Registered {
|
if !machine.Registered {
|
||||||
|
|
||||||
log.Debug().Msg("Registering new machine after successful callback")
|
log.Debug().Msg("Registering new machine after successful callback")
|
||||||
|
|
||||||
ns, err := h.GetNamespace(nsName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ns, err = h.CreateNamespace(nsName)
|
namespace, err = h.CreateNamespace(namespaceName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
log.Error().
|
||||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
Msgf("could not create new namespace '%s'", claims.Email)
|
||||||
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not create new namespace",
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not get an IP from the pool",
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = ns.ID
|
machine.NamespaceID = namespace.ID
|
||||||
m.Registered = true
|
machine.Registered = true
|
||||||
m.RegisterMethod = "oidc"
|
machine.RegisterMethod = "oidc"
|
||||||
m.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateMachineExpiry(m)
|
h.updateMachineExpiry(machine)
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>headscale</h1>
|
<h1>headscale</h1>
|
||||||
|
@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
`, claims.Email)))
|
`, claims.Email)))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("email", claims.Email).
|
Str("email", claims.Email).
|
||||||
Str("username", claims.Username).
|
Str("username", claims.Username).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Email could not be mapped to a namespace")
|
Msg("Email could not be mapped to a namespace")
|
||||||
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
ctx.String(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"email from claim could not be mapped to a namespace",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNamespaceFromEmail passes the users email through a list of "matchers"
|
// getNamespaceFromEmail passes the users email through a list of "matchers"
|
||||||
|
|
50
oidc_test.go
50
oidc_test.go
|
@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
//nolint
|
//nolint
|
||||||
for _, tt := range tests {
|
for _, test := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
h := &Headscale{
|
app := &Headscale{
|
||||||
cfg: tt.fields.cfg,
|
cfg: test.fields.cfg,
|
||||||
db: tt.fields.db,
|
db: test.fields.db,
|
||||||
dbString: tt.fields.dbString,
|
dbString: test.fields.dbString,
|
||||||
dbType: tt.fields.dbType,
|
dbType: test.fields.dbType,
|
||||||
dbDebug: tt.fields.dbDebug,
|
dbDebug: test.fields.dbDebug,
|
||||||
publicKey: tt.fields.publicKey,
|
publicKey: test.fields.publicKey,
|
||||||
privateKey: tt.fields.privateKey,
|
privateKey: test.fields.privateKey,
|
||||||
aclPolicy: tt.fields.aclPolicy,
|
aclPolicy: test.fields.aclPolicy,
|
||||||
aclRules: tt.fields.aclRules,
|
aclRules: test.fields.aclRules,
|
||||||
lastStateChange: tt.fields.lastStateChange,
|
lastStateChange: test.fields.lastStateChange,
|
||||||
oidcProvider: tt.fields.oidcProvider,
|
oidcProvider: test.fields.oidcProvider,
|
||||||
oauth2Config: tt.fields.oauth2Config,
|
oauth2Config: test.fields.oauth2Config,
|
||||||
oidcStateCache: tt.fields.oidcStateCache,
|
oidcStateCache: test.fields.oidcStateCache,
|
||||||
}
|
}
|
||||||
got, got1 := h.getNamespaceFromEmail(tt.args.email)
|
got, got1 := app.getNamespaceFromEmail(test.args.email)
|
||||||
if got != tt.want {
|
if got != test.want {
|
||||||
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
|
t.Errorf(
|
||||||
|
"Headscale.getNamespaceFromEmail() got = %v, want %v",
|
||||||
|
got,
|
||||||
|
test.want,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if got1 != tt.want1 {
|
if got1 != test.want1 {
|
||||||
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
|
t.Errorf(
|
||||||
|
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
|
||||||
|
got1,
|
||||||
|
test.want1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
263
poll.go
263
poll.go
|
@ -15,6 +15,11 @@ import (
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
keepAliveInterval = 60 * time.Second
|
||||||
|
updateCheckInterval = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
// PollNetMapHandler takes care of /machine/:id/map
|
// PollNetMapHandler takes care of /machine/:id/map
|
||||||
//
|
//
|
||||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
||||||
|
@ -24,20 +29,21 @@ import (
|
||||||
// only after their first request (marked with the ReadOnly field).
|
// only after their first request (marked with the ReadOnly field).
|
||||||
//
|
//
|
||||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Msg("PollNetMapHandler called")
|
Msg("PollNetMapHandler called")
|
||||||
body, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
mKeyStr := c.Param("id")
|
mKeyStr := ctx.Param("id")
|
||||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot parse client key")
|
Msg("Cannot parse client key")
|
||||||
c.String(http.StatusBadRequest, "")
|
ctx.String(http.StatusBadRequest, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req := tailcfg.MapRequest{}
|
req := tailcfg.MapRequest{}
|
||||||
|
@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot decode message")
|
Msg("Cannot decode message")
|
||||||
c.String(http.StatusBadRequest, "")
|
ctx.String(http.StatusBadRequest, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
machine, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||||
c.String(http.StatusUnauthorized, "")
|
ctx.String(http.StatusUnauthorized, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
|
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Found machine in database")
|
Msg("Found machine in database")
|
||||||
|
|
||||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||||
m.Name = req.Hostinfo.Hostname
|
machine.Name = req.Hostinfo.Hostname
|
||||||
m.HostInfo = datatypes.JSON(hostinfo)
|
machine.HostInfo = datatypes.JSON(hostinfo)
|
||||||
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
// From Tailscale client:
|
// From Tailscale client:
|
||||||
|
@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
// before their first real endpoint update.
|
// before their first real endpoint update.
|
||||||
if !req.ReadOnly {
|
if !req.ReadOnly {
|
||||||
endpoints, _ := json.Marshal(req.Endpoints)
|
endpoints, _ := json.Marshal(req.Endpoints)
|
||||||
m.Endpoints = datatypes.JSON(endpoints)
|
machine.Endpoints = datatypes.JSON(endpoints)
|
||||||
m.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
data, err := h.getMapResponse(mKey, req, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to get Map response")
|
Msg("Failed to get Map response")
|
||||||
c.String(http.StatusInternalServerError, ":(")
|
ctx.String(http.StatusInternalServerError, ":(")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Bool("readOnly", req.ReadOnly).
|
Bool("readOnly", req.ReadOnly).
|
||||||
Bool("omitPeers", req.OmitPeers).
|
Bool("omitPeers", req.OmitPeers).
|
||||||
Bool("stream", req.Stream).
|
Bool("stream", req.Stream).
|
||||||
|
@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
if req.ReadOnly {
|
if req.ReadOnly {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client is starting up. Probably interested in a DERP map")
|
Msg("Client is starting up. Probably interested in a DERP map")
|
||||||
c.Data(200, "application/json; charset=utf-8", data)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// There has been an update to _any_ of the nodes that the other nodes would
|
// There has been an update to _any_ of the nodes that the other nodes would
|
||||||
// need to know about
|
// need to know about
|
||||||
h.setLastStateChangeToNow(m.Namespace.Name)
|
h.setLastStateChangeToNow(machine.Namespace.Name)
|
||||||
|
|
||||||
// The request is not ReadOnly, so we need to set up channels for updating
|
// The request is not ReadOnly, so we need to set up channels for updating
|
||||||
// peers via longpoll
|
// peers via longpoll
|
||||||
|
@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
// Only create update channel if it has not been created
|
// Only create update channel if it has not been created
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Loading or creating update channel")
|
Msg("Loading or creating update channel")
|
||||||
updateChan := make(chan struct{})
|
updateChan := make(chan struct{})
|
||||||
|
|
||||||
|
@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
if req.OmitPeers && !req.Stream {
|
if req.OmitPeers && !req.Stream {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
Msg("Client sent endpoint update and is ok with a response without peer list")
|
||||||
c.Data(200, "application/json; charset=utf-8", data)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
|
||||||
|
|
||||||
// It sounds like we should update the nodes when we have received a endpoint update
|
// It sounds like we should update the nodes when we have received a endpoint update
|
||||||
// even tho the comments in the tailscale code dont explicitly say so.
|
// even tho the comments in the tailscale code dont explicitly say so.
|
||||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc()
|
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
|
||||||
|
Inc()
|
||||||
go func() { updateChan <- struct{}{} }()
|
go func() { updateChan <- struct{}{} }()
|
||||||
|
|
||||||
return
|
return
|
||||||
} else if req.OmitPeers && req.Stream {
|
} else if req.OmitPeers && req.Stream {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Ignoring request, don't know how to handle it")
|
Msg("Ignoring request, don't know how to handle it")
|
||||||
c.String(http.StatusBadRequest, "")
|
ctx.String(http.StatusBadRequest, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client is ready to access the tailnet")
|
Msg("Client is ready to access the tailnet")
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Sending initial map")
|
Msg("Sending initial map")
|
||||||
go func() { pollDataChan <- data }()
|
go func() { pollDataChan <- data }()
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Notifying peers")
|
Msg("Notifying peers")
|
||||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc()
|
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
|
||||||
|
Inc()
|
||||||
go func() { updateChan <- struct{}{} }()
|
go func() { updateChan <- struct{}{} }()
|
||||||
|
|
||||||
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
|
h.PollNetMapStream(
|
||||||
|
ctx,
|
||||||
|
machine,
|
||||||
|
req,
|
||||||
|
mKey,
|
||||||
|
pollDataChan,
|
||||||
|
keepAliveChan,
|
||||||
|
updateChan,
|
||||||
|
cancelKeepAlive,
|
||||||
|
)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("id", c.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finished stream, closing PollNetMap session")
|
Msg("Finished stream, closing PollNetMap session")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
// stream logic, ensuring we communicate updates and data
|
// stream logic, ensuring we communicate updates and data
|
||||||
// to the connected clients.
|
// to the connected clients.
|
||||||
func (h *Headscale) PollNetMapStream(
|
func (h *Headscale) PollNetMapStream(
|
||||||
c *gin.Context,
|
ctx *gin.Context,
|
||||||
m *Machine,
|
machine *Machine,
|
||||||
req tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
mKey wgkey.Key,
|
machineKey wgkey.Key,
|
||||||
pollDataChan chan []byte,
|
pollDataChan chan []byte,
|
||||||
keepAliveChan chan []byte,
|
keepAliveChan chan []byte,
|
||||||
updateChan chan struct{},
|
updateChan chan struct{},
|
||||||
cancelKeepAlive chan struct{},
|
cancelKeepAlive chan struct{},
|
||||||
) {
|
) {
|
||||||
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
|
go h.scheduledPollWorker(
|
||||||
|
cancelKeepAlive,
|
||||||
|
updateChan,
|
||||||
|
keepAliveChan,
|
||||||
|
machineKey,
|
||||||
|
mapRequest,
|
||||||
|
machine,
|
||||||
|
)
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
ctx.Stream(func(writer io.Writer) bool {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Waiting for data to stream...")
|
Msg("Waiting for data to stream...")
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case data := <-pollDataChan:
|
case data := <-pollDataChan:
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Sending data received via pollData channel")
|
Msg("Sending data received via pollData channel")
|
||||||
_, err := w.Write(data)
|
_, err := writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot write data")
|
Msg("Cannot write data")
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Data from pollData channel written successfully")
|
Msg("Data from pollData channel written successfully")
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachine(m)
|
err = h.UpdateMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot update machine from database")
|
Msg("Cannot update machine from database")
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
|
|
||||||
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
|
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
|
||||||
m.LastSuccessfulUpdate = &now
|
Set(float64(now.Unix()))
|
||||||
|
machine.LastSuccessfulUpdate = &now
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Machine entry in database updated successfully after sending pollData")
|
Msg("Machine entry in database updated successfully after sending pollData")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case data := <-keepAliveChan:
|
case data := <-keepAliveChan:
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Sending keep alive message")
|
Msg("Sending keep alive message")
|
||||||
_, err := w.Write(data)
|
_, err := writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot write keep alive message")
|
Msg("Cannot write keep alive message")
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Keep alive sent successfully")
|
Msg("Keep alive sent successfully")
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachine(m)
|
err = h.UpdateMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot update machine from database")
|
Msg("Cannot update machine from database")
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Machine updated successfully after sending keep alive")
|
Msg("Machine updated successfully after sending keep alive")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-updateChan:
|
case <-updateChan:
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "update").
|
Str("channel", "update").
|
||||||
Msg("Received a request for update")
|
Msg("Received a request for update")
|
||||||
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc()
|
updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
|
||||||
if h.isOutdated(m) {
|
Inc()
|
||||||
|
if h.isOutdated(machine) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||||
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
|
||||||
Msgf("There has been updates since the last successful update to %s", m.Name)
|
Msgf("There has been updates since the last successful update to %s", machine.Name)
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
data, err := h.getMapResponse(machineKey, mapRequest, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "update").
|
Str("channel", "update").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not get the map update")
|
Msg("Could not get the map update")
|
||||||
}
|
}
|
||||||
_, err = w.Write(data)
|
_, err = writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "update").
|
Str("channel", "update").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not write the map response")
|
Msg("Could not write the map response")
|
||||||
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc()
|
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
|
||||||
|
Inc()
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "update").
|
Str("channel", "update").
|
||||||
Msg("Updated Map has been sent")
|
Msg("Updated Map has been sent")
|
||||||
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc()
|
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
|
||||||
|
Inc()
|
||||||
|
|
||||||
// Keep track of the last successful update,
|
// Keep track of the last successful update,
|
||||||
// we sometimes end in a state were the update
|
// we sometimes end in a state were the update
|
||||||
|
@ -366,62 +405,64 @@ func (h *Headscale) PollNetMapStream(
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachine(m)
|
err = h.UpdateMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "update").
|
Str("channel", "update").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot update machine from database")
|
Msg("Cannot update machine from database")
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
|
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
|
||||||
m.LastSuccessfulUpdate = &now
|
Set(float64(now.Unix()))
|
||||||
|
machine.LastSuccessfulUpdate = &now
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
} else {
|
} else {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||||
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
|
||||||
Msgf("%s is up to date", m.Name)
|
Msgf("%s is up to date", machine.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-c.Request.Context().Done():
|
case <-ctx.Request.Context().Done():
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("The client has closed the connection")
|
Msg("The client has closed the connection")
|
||||||
// TODO: Abstract away all the database calls, this can cause race conditions
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err := h.UpdateMachine(m)
|
err := h.UpdateMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "Done").
|
Str("channel", "Done").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot update machine from database")
|
Msg("Cannot update machine from database")
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "Done").
|
Str("channel", "Done").
|
||||||
Msg("Cancelling keepAlive channel")
|
Msg("Cancelling keepAlive channel")
|
||||||
cancelKeepAlive <- struct{}{}
|
cancelKeepAlive <- struct{}{}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "Done").
|
Str("channel", "Done").
|
||||||
Msg("Closing update channel")
|
Msg("Closing update channel")
|
||||||
// h.closeUpdateChannel(m)
|
// h.closeUpdateChannel(m)
|
||||||
|
@ -429,14 +470,14 @@ func (h *Headscale) PollNetMapStream(
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "Done").
|
Str("channel", "Done").
|
||||||
Msg("Closing pollData channel")
|
Msg("Closing pollData channel")
|
||||||
close(pollDataChan)
|
close(pollDataChan)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("channel", "Done").
|
Str("channel", "Done").
|
||||||
Msg("Closing keepAliveChan channel")
|
Msg("Closing keepAliveChan channel")
|
||||||
close(keepAliveChan)
|
close(keepAliveChan)
|
||||||
|
@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
cancelChan <-chan struct{},
|
cancelChan <-chan struct{},
|
||||||
updateChan chan<- struct{},
|
updateChan chan<- struct{},
|
||||||
keepAliveChan chan<- []byte,
|
keepAliveChan chan<- []byte,
|
||||||
mKey wgkey.Key,
|
machineKey wgkey.Key,
|
||||||
req tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
m *Machine,
|
machine *Machine,
|
||||||
) {
|
) {
|
||||||
keepAliveTicker := time.NewTicker(60 * time.Second)
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
updateCheckerTicker := time.NewTicker(10 * time.Second)
|
updateCheckerTicker := time.NewTicker(updateCheckInterval)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-keepAliveTicker.C:
|
case <-keepAliveTicker.C:
|
||||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "keepAlive").
|
Str("func", "keepAlive").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Error generating the keep alive msg")
|
Msg("Error generating the keep alive msg")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "keepAlive").
|
Str("func", "keepAlive").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Sending keepalive")
|
Msg("Sending keepalive")
|
||||||
keepAliveChan <- data
|
keepAliveChan <- data
|
||||||
|
|
||||||
case <-updateCheckerTicker.C:
|
case <-updateCheckerTicker.C:
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "scheduledPollWorker").
|
Str("func", "scheduledPollWorker").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Sending update request")
|
Msg("Sending update request")
|
||||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc()
|
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
|
||||||
|
Inc()
|
||||||
updateChan <- struct{}{}
|
updateChan <- struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,19 +7,19 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errorAuthKeyNotFound = Error("AuthKey not found")
|
errPreAuthKeyNotFound = Error("AuthKey not found")
|
||||||
errorAuthKeyExpired = Error("AuthKey expired")
|
errPreAuthKeyExpired = Error("AuthKey expired")
|
||||||
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
|
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
|
||||||
|
errNamespaceMismatch = Error("namespace mismatch")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PreAuthKey describes a pre-authorization key usable in a particular namespace
|
// PreAuthKey describes a pre-authorization key usable in a particular namespace.
|
||||||
type PreAuthKey struct {
|
type PreAuthKey struct {
|
||||||
ID uint64 `gorm:"primary_key"`
|
ID uint64 `gorm:"primary_key"`
|
||||||
Key string
|
Key string
|
||||||
|
@ -33,14 +33,14 @@ type PreAuthKey struct {
|
||||||
Expiration *time.Time
|
Expiration *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
|
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
|
||||||
func (h *Headscale) CreatePreAuthKey(
|
func (h *Headscale) CreatePreAuthKey(
|
||||||
namespaceName string,
|
namespaceName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
) (*PreAuthKey, error) {
|
) (*PreAuthKey, error) {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
k := PreAuthKey{
|
key := PreAuthKey{
|
||||||
Key: kstr,
|
Key: kstr,
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Namespace: *n,
|
Namespace: *namespace,
|
||||||
Reusable: reusable,
|
Reusable: reusable,
|
||||||
Ephemeral: ephemeral,
|
Ephemeral: ephemeral,
|
||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
h.db.Save(&k)
|
h.db.Save(&key)
|
||||||
|
|
||||||
return &k, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace
|
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
|
||||||
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
|
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := []PreAuthKey{}
|
keys := []PreAuthKey{}
|
||||||
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
|
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return keys, nil
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreAuthKey returns a PreAuthKey for a given key
|
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||||
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
|
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
|
||||||
pak, err := h.checkKeyValidity(key)
|
pak, err := h.checkKeyValidity(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
|
||||||
}
|
}
|
||||||
|
|
||||||
if pak.Namespace.Name != namespace {
|
if pak.Namespace.Name != namespace {
|
||||||
return nil, errors.New("Namespace mismatch")
|
return nil, errNamespaceMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
return pak, nil
|
return pak, nil
|
||||||
|
@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
|
||||||
|
|
||||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||||
// does not exist.
|
// does not exist.
|
||||||
func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error {
|
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||||
if result := h.db.Unscoped().Delete(&pak); result.Error != nil {
|
if result := h.db.Unscoped().Delete(pak); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired
|
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||||
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||||
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||||
// If returns no error and a PreAuthKey, it can be used
|
// If returns no error and a PreAuthKey, it can be used.
|
||||||
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
pak := PreAuthKey{}
|
pak := PreAuthKey{}
|
||||||
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
|
||||||
return nil, errorAuthKeyNotFound
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return nil, errPreAuthKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||||
return nil, errorAuthKeyExpired
|
return nil, errPreAuthKeyExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
|
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
|
||||||
|
@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) {
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hex.EncodeToString(bytes), nil
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
||||||
protoKey := v1.PreAuthKey{
|
protoKey := v1.PreAuthKey{
|
||||||
Namespace: key.Namespace.Name,
|
Namespace: key.Namespace.Name,
|
||||||
Id: strconv.FormatUint(key.ID, 10),
|
Id: strconv.FormatUint(key.ID, Base10),
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Ephemeral: key.Ephemeral,
|
Ephemeral: key.Ephemeral,
|
||||||
Reusable: key.Reusable,
|
Reusable: key.Reusable,
|
||||||
|
|
|
@ -7,189 +7,189 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := h.CreatePreAuthKey("bogus", true, false, nil)
|
_, err := app.CreatePreAuthKey("bogus", true, false, nil)
|
||||||
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
k, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
c.Assert(k.Key, check.NotNil)
|
c.Assert(key.Key, check.NotNil)
|
||||||
c.Assert(len(k.Key), check.Equals, 48)
|
c.Assert(len(key.Key), check.Equals, 48)
|
||||||
|
|
||||||
// Make sure the Namespace association is populated
|
// Make sure the Namespace association is populated
|
||||||
c.Assert(k.Namespace.Name, check.Equals, n.Name)
|
c.Assert(key.Namespace.Name, check.Equals, namespace.Name)
|
||||||
|
|
||||||
_, err = h.ListPreAuthKeys("bogus")
|
_, err = app.ListPreAuthKeys("bogus")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := h.ListPreAuthKeys(n.Name)
|
keys, err := app.ListPreAuthKeys(namespace.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
// Make sure the Namespace association is populated
|
// Make sure the Namespace association is populated
|
||||||
c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name)
|
c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test2")
|
namespace, err := app.CreateNamespace("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, &now)
|
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, errorAuthKeyExpired)
|
c.Assert(err, check.Equals, errPreAuthKeyExpired)
|
||||||
c.Assert(p, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||||
p, err := h.checkKeyValidity("potatoKey")
|
key, err := app.checkKeyValidity("potatoKey")
|
||||||
c.Assert(err, check.Equals, errorAuthKeyNotFound)
|
c.Assert(err, check.Equals, errPreAuthKeyNotFound)
|
||||||
c.Assert(p, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test3")
|
namespace, err := app.CreateNamespace("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(p.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test4")
|
namespace, err := app.CreateNamespace("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testest",
|
Name: "testest",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
||||||
c.Assert(p, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test5")
|
namespace, err := app.CreateNamespace("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testest",
|
Name: "testest",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(p.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test6")
|
namespace, err := app.CreateNamespace("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(p.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test7")
|
namespace, err := app.CreateNamespace("test7")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, true, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testest",
|
Name: "testest",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
_, err = h.checkKeyValidity(pak.Key)
|
_, err = app.checkKeyValidity(pak.Key)
|
||||||
// Ephemeral keys are by definition reusable
|
// Ephemeral keys are by definition reusable
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test7", "testest")
|
_, err = app.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
h.expireEphemeralNodesWorker()
|
app.expireEphemeralNodesWorker()
|
||||||
|
|
||||||
// The machine record should have been deleted
|
// The machine record should have been deleted
|
||||||
_, err = h.GetMachine("test7", "testest")
|
_, err = app.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test3")
|
namespace, err := app.CreateNamespace("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
err = h.ExpirePreAuthKey(pak)
|
err = app.ExpirePreAuthKey(pak)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.NotNil)
|
c.Assert(pak.Expiration, check.NotNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
key, err := app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, errorAuthKeyExpired)
|
c.Assert(err, check.Equals, errPreAuthKeyExpired)
|
||||||
c.Assert(p, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test6")
|
namespace, err := app.CreateNamespace("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
h.db.Save(&pak)
|
app.db.Save(&pak)
|
||||||
|
|
||||||
_, err = h.checkKeyValidity(pak.Key)
|
_, err = app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
||||||
}
|
}
|
||||||
|
|
55
routes.go
55
routes.go
|
@ -2,38 +2,48 @@ package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
errRouteIsNotAvailable = Error("route is not available")
|
||||||
|
)
|
||||||
|
|
||||||
// Deprecated: use machine function instead
|
// Deprecated: use machine function instead
|
||||||
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
|
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
|
||||||
// namespace and node name)
|
// namespace and node name).
|
||||||
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
func (h *Headscale) GetAdvertisedNodeRoutes(
|
||||||
m, err := h.GetMachine(namespace, nodeName)
|
namespace string,
|
||||||
|
nodeName string,
|
||||||
|
) (*[]netaddr.IPPrefix, error) {
|
||||||
|
machine, err := h.GetMachine(namespace, nodeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo, err := m.GetHostInfo()
|
hostInfo, err := machine.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &hostInfo.RoutableIPs, nil
|
return &hostInfo.RoutableIPs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use machine function instead
|
// Deprecated: use machine function instead
|
||||||
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
|
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
|
||||||
// namespace and node name)
|
// namespace and node name).
|
||||||
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
|
func (h *Headscale) GetEnabledNodeRoutes(
|
||||||
m, err := h.GetMachine(namespace, nodeName)
|
namespace string,
|
||||||
|
nodeName string,
|
||||||
|
) ([]netaddr.IPPrefix, error) {
|
||||||
|
machine, err := h.GetMachine(namespace, nodeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := m.EnabledRoutes.MarshalJSON()
|
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use machine function instead
|
// Deprecated: use machine function instead
|
||||||
// IsNodeRouteEnabled checks if a certain route has been enabled
|
// IsNodeRouteEnabled checks if a certain route has been enabled.
|
||||||
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
|
func (h *Headscale) IsNodeRouteEnabled(
|
||||||
|
namespace string,
|
||||||
|
nodeName string,
|
||||||
|
routeStr string,
|
||||||
|
) bool {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
|
@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use EnableRoute in machine.go
|
// Deprecated: use EnableRoute in machine.go
|
||||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
||||||
// namespace and node name)
|
// namespace and node name).
|
||||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
|
func (h *Headscale) EnableNodeRoute(
|
||||||
m, err := h.GetMachine(namespace, nodeName)
|
namespace string,
|
||||||
|
nodeName string,
|
||||||
|
routeStr string,
|
||||||
|
) error {
|
||||||
|
machine, err := h.GetMachine(namespace, nodeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
|
||||||
}
|
}
|
||||||
|
|
||||||
if !available {
|
if !available {
|
||||||
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
|
return errRouteIsNotAvailable
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := json.Marshal(enabledRoutes)
|
routes, err := json.Marshal(enabledRoutes)
|
||||||
|
@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.EnabledRoutes = datatypes.JSON(routes)
|
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
err = h.RequestMapUpdates(m.NamespaceID)
|
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,57 +10,60 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "test_get_route_machine")
|
_, err = app.GetMachine("test", "test_get_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
|
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
hi := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netaddr.IPPrefix{route},
|
RoutableIPs: []netaddr.IPPrefix{route},
|
||||||
}
|
}
|
||||||
hostinfo, err := json.Marshal(hi)
|
hostinfo, err := json.Marshal(hostInfo)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "test_get_route_machine",
|
Name: "test_get_route_machine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: datatypes.JSON(hostinfo),
|
HostInfo: datatypes.JSON(hostinfo),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
|
advertisedRoutes, err := app.GetAdvertisedNodeRoutes(
|
||||||
|
"test",
|
||||||
|
"test_get_route_machine",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(*r), check.Equals, 1)
|
c.Assert(len(*advertisedRoutes), check.Equals, 1)
|
||||||
|
|
||||||
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
|
err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
|
err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
namespace, err := app.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "test_enable_route_machine")
|
_, err = app.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netaddr.ParseIPPrefix(
|
route, err := netaddr.ParseIPPrefix(
|
||||||
|
@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
hi := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netaddr.IPPrefix{route, route2},
|
RoutableIPs: []netaddr.IPPrefix{route, route2},
|
||||||
}
|
}
|
||||||
hostinfo, err := json.Marshal(hi)
|
hostinfo, err := json.Marshal(hostInfo)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "test_enable_route_machine",
|
Name: "test_enable_route_machine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: datatypes.JSON(hostinfo),
|
HostInfo: datatypes.JSON(hostinfo),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
|
availableRoutes, err := app.GetAdvertisedNodeRoutes(
|
||||||
|
"test",
|
||||||
|
"test_enable_route_machine",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(*availableRoutes), check.Equals, 2)
|
c.Assert(len(*availableRoutes), check.Equals, 2)
|
||||||
|
|
||||||
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
noEnabledRoutes, err := app.GetEnabledNodeRoutes(
|
||||||
|
"test",
|
||||||
|
"test_enable_route_machine",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes), check.Equals, 0)
|
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||||
|
|
||||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
|
err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||||
|
|
||||||
// Adding it twice will just let it pass through
|
// Adding it twice will just let it pass through
|
||||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes(
|
||||||
|
"test",
|
||||||
|
"test_enable_route_machine",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||||
|
|
||||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
|
err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes(
|
||||||
|
"test",
|
||||||
|
"test_enable_route_machine",
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes3), check.Equals, 2)
|
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||||
}
|
}
|
||||||
|
|
58
sharing.go
58
sharing.go
|
@ -2,11 +2,13 @@ package headscale
|
||||||
|
|
||||||
import "gorm.io/gorm"
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
const errorSameNamespace = Error("Destination namespace same as origin")
|
const (
|
||||||
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
|
errSameNamespace = Error("Destination namespace same as origin")
|
||||||
const errorMachineNotShared = Error("Machine not shared to this namespace")
|
errMachineAlreadyShared = Error("Node already shared to this namespace")
|
||||||
|
errMachineNotShared = Error("Machine not shared to this namespace")
|
||||||
|
)
|
||||||
|
|
||||||
// SharedMachine is a join table to support sharing nodes between namespaces
|
// SharedMachine is a join table to support sharing nodes between namespaces.
|
||||||
type SharedMachine struct {
|
type SharedMachine struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
MachineID uint64
|
MachineID uint64
|
||||||
|
@ -15,49 +17,57 @@ type SharedMachine struct {
|
||||||
Namespace Namespace
|
Namespace Namespace
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
|
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
|
||||||
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
|
func (h *Headscale) AddSharedMachineToNamespace(
|
||||||
if m.NamespaceID == ns.ID {
|
machine *Machine,
|
||||||
return errorSameNamespace
|
namespace *Namespace,
|
||||||
|
) error {
|
||||||
|
if machine.NamespaceID == namespace.ID {
|
||||||
|
return errSameNamespace
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
|
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(sharedMachines) > 0 {
|
if len(sharedMachines) > 0 {
|
||||||
return errorMachineAlreadyShared
|
return errMachineAlreadyShared
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachine := SharedMachine{
|
sharedMachine := SharedMachine{
|
||||||
MachineID: m.ID,
|
MachineID: machine.ID,
|
||||||
Machine: *m,
|
Machine: *machine,
|
||||||
NamespaceID: ns.ID,
|
NamespaceID: namespace.ID,
|
||||||
Namespace: *ns,
|
Namespace: *namespace,
|
||||||
}
|
}
|
||||||
h.db.Save(&sharedMachine)
|
h.db.Save(&sharedMachine)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace
|
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
|
||||||
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
|
func (h *Headscale) RemoveSharedMachineFromNamespace(
|
||||||
if m.NamespaceID == ns.ID {
|
machine *Machine,
|
||||||
|
namespace *Namespace,
|
||||||
|
) error {
|
||||||
|
if machine.NamespaceID == namespace.ID {
|
||||||
// Can't unshare from primary namespace
|
// Can't unshare from primary namespace
|
||||||
return errorMachineNotShared
|
return errMachineNotShared
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachine := SharedMachine{}
|
sharedMachine := SharedMachine{}
|
||||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine)
|
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
|
||||||
|
Unscoped().
|
||||||
|
Delete(&sharedMachine)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
return errorMachineNotShared
|
return errMachineNotShared
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.RequestMapUpdates(ns.ID)
|
err := h.RequestMapUpdates(namespace.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces
|
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
|
||||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
|
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
|
||||||
sharedMachine := SharedMachine{}
|
sharedMachine := SharedMachine{}
|
||||||
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
227
sharing_test.go
227
sharing_test.go
|
@ -4,45 +4,48 @@ import (
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) {
|
func CreateNodeNamespace(
|
||||||
n1, err := h.CreateNamespace(namespace)
|
c *check.C,
|
||||||
|
namespaceName, node, key, ip string,
|
||||||
|
) (*Namespace, *Machine) {
|
||||||
|
namespace, err := app.CreateNamespace(namespaceName)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, node)
|
_, err = app.GetMachine(namespace.Name, node)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m1 := &Machine{
|
machine := &Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: key,
|
MachineKey: key,
|
||||||
NodeKey: key,
|
NodeKey: key,
|
||||||
DiscoKey: key,
|
DiscoKey: key,
|
||||||
Name: node,
|
Name: node,
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: IP,
|
IPAddress: ip,
|
||||||
AuthKeyID: uint(pak1.ID),
|
AuthKeyID: uint(pak1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m1)
|
app.db.Save(machine)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
_, err = app.GetMachine(namespace.Name, machine.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
return n1, m1
|
return namespace, machine
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_get_shared_nodes_2",
|
"test_get_shared_nodes_2",
|
||||||
|
@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShared, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1sAfter, err := h.getPeers(m1)
|
peersOfMachine1AfterShared, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1sAfter), check.Equals, 1)
|
c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1)
|
||||||
c.Assert(p1sAfter[0].ID, check.Equals, m2.ID)
|
c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSameNamespace(c *check.C) {
|
func (s *Suite) TestSameNamespace(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
|
@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) {
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m1, n1)
|
err = app.AddSharedMachineToNamespace(machine1, namespace1)
|
||||||
c.Assert(err, check.Equals, errorSameNamespace)
|
c.Assert(err, check.Equals, errSameNamespace)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestUnshare(c *check.C) {
|
func (s *Suite) TestUnshare(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_unshare_1",
|
"test_unshare_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_unshare_2",
|
"test_unshare_2",
|
||||||
|
@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) {
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1s, err = h.getShared(m1)
|
peersOfMachine1BeforeShare, err = app.getShared(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 1)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1)
|
||||||
|
|
||||||
err = h.RemoveSharedMachineFromNamespace(m2, n1)
|
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1s, err = h.getShared(m1)
|
peersOfMachine1BeforeShare, err = app.getShared(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||||
|
|
||||||
err = h.RemoveSharedMachineFromNamespace(m2, n1)
|
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.Equals, errorMachineNotShared)
|
c.Assert(err, check.Equals, errMachineNotShared)
|
||||||
|
|
||||||
err = h.RemoveSharedMachineFromNamespace(m1, n1)
|
err = app.RemoveSharedMachineFromNamespace(machine1, namespace1)
|
||||||
c.Assert(err, check.Equals, errorMachineNotShared)
|
c.Assert(err, check.Equals, errMachineNotShared)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestAlreadyShared(c *check.C) {
|
func (s *Suite) TestAlreadyShared(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_get_shared_nodes_2",
|
"test_get_shared_nodes_2",
|
||||||
|
@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) {
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.Equals, errorMachineAlreadyShared)
|
c.Assert(err, check.Equals, errMachineAlreadyShared)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_get_shared_nodes_2",
|
"test_get_shared_nodes_2",
|
||||||
|
@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 0)
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1sAfter, err := h.getPeers(m1)
|
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1sAfter), check.Equals, 1)
|
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1)
|
||||||
c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2")
|
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_get_shared_nodes_2",
|
"test_get_shared_nodes_2",
|
||||||
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
_, m3 := CreateNodeNamespace(
|
_, machine3 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared3",
|
"shared3",
|
||||||
"test_get_shared_nodes_3",
|
"test_get_shared_nodes_3",
|
||||||
|
@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
||||||
"100.64.0.3",
|
"100.64.0.3",
|
||||||
)
|
)
|
||||||
|
|
||||||
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
m4 := &Machine{
|
machine4 := &Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
Name: "test_get_shared_nodes_4",
|
Name: "test_get_shared_nodes_4",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespace1.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.4",
|
IPAddress: "100.64.0.4",
|
||||||
AuthKeyID: uint(pak4.ID),
|
AuthKeyID: uint(pak4.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m4)
|
app.db.Save(machine4)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m4.Name)
|
_, err = app.GetMachine(namespace1.Name, machine4.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 1) // node1 can see node4
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4
|
||||||
c.Assert(p1s[0].Name, check.Equals, m4.Name)
|
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1sAfter, err := h.getPeers(m1)
|
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
|
|
||||||
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
|
|
||||||
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
|
|
||||||
|
|
||||||
node1shared, err := h.getShared(m1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
|
|
||||||
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
|
|
||||||
|
|
||||||
pAlone, err := h.getPeers(m3)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
|
|
||||||
|
|
||||||
pSharedTo, err := h.getPeers(m2)
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
len(pSharedTo),
|
len(peersOfMachine1AfterShare),
|
||||||
|
check.Equals,
|
||||||
|
2,
|
||||||
|
) // node1 can see node2 (shared) and node4 (same namespace)
|
||||||
|
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
|
||||||
|
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
|
||||||
|
|
||||||
|
sharedOfMachine1, err := app.getShared(machine1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared
|
||||||
|
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
|
||||||
|
|
||||||
|
peersOfMachine3, err := app.getPeers(machine3)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone
|
||||||
|
|
||||||
|
peersOfMachine2, err := app.getPeers(machine2)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(
|
||||||
|
len(peersOfMachine2),
|
||||||
check.Equals,
|
check.Equals,
|
||||||
2,
|
2,
|
||||||
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
|
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
|
||||||
c.Assert(pSharedTo[0].Name, check.Equals, m1.Name)
|
c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name)
|
||||||
c.Assert(pSharedTo[1].Name, check.Equals, m4.Name)
|
c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
|
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
|
||||||
n1, m1 := CreateNodeNamespace(
|
namespace1, machine1 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared1",
|
"shared1",
|
||||||
"test_get_shared_nodes_1",
|
"test_get_shared_nodes_1",
|
||||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
"100.64.0.1",
|
"100.64.0.1",
|
||||||
)
|
)
|
||||||
_, m2 := CreateNodeNamespace(
|
_, machine2 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared2",
|
"shared2",
|
||||||
"test_get_shared_nodes_2",
|
"test_get_shared_nodes_2",
|
||||||
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
"100.64.0.2",
|
"100.64.0.2",
|
||||||
)
|
)
|
||||||
_, m3 := CreateNodeNamespace(
|
_, machine3 := CreateNodeNamespace(
|
||||||
c,
|
c,
|
||||||
"shared3",
|
"shared3",
|
||||||
"test_get_shared_nodes_3",
|
"test_get_shared_nodes_3",
|
||||||
|
@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
|
||||||
"100.64.0.3",
|
"100.64.0.3",
|
||||||
)
|
)
|
||||||
|
|
||||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
m4 := &Machine{
|
machine4 := &Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||||
Name: "test_get_shared_nodes_4",
|
Name: "test_get_shared_nodes_4",
|
||||||
NamespaceID: n1.ID,
|
NamespaceID: namespace1.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
IPAddress: "100.64.0.4",
|
IPAddress: "100.64.0.4",
|
||||||
AuthKeyID: uint(pak4n1.ID),
|
AuthKeyID: uint(pak4n1.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(m4)
|
app.db.Save(machine4)
|
||||||
|
|
||||||
_, err = h.GetMachine(n1.Name, m4.Name)
|
_, err = app.GetMachine(namespace1.Name, machine4.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1s, err := h.getPeers(m1)
|
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4
|
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4
|
||||||
c.Assert(p1s[0].Name, check.Equals, m4.Name)
|
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
|
||||||
|
|
||||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p1sAfter, err := h.getPeers(m1)
|
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4
|
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4
|
||||||
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
|
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
|
||||||
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
|
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
|
||||||
|
|
||||||
node1shared, err := h.getShared(m1)
|
sharedOfMachine1, err := app.getShared(machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4
|
c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4
|
||||||
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
|
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
|
||||||
|
|
||||||
pAlone, err := h.getPeers(m3)
|
peersOfMachine3, err := app.getPeers(machine3)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone
|
c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone
|
||||||
|
|
||||||
sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name)
|
sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace(
|
||||||
|
namespace1.Name,
|
||||||
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(sharedMachines), check.Equals, 1)
|
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1)
|
||||||
|
|
||||||
err = h.DeleteMachine(m2)
|
err = app.DeleteMachine(machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name)
|
sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(sharedMachines), check.Equals, 0)
|
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0)
|
||||||
}
|
}
|
||||||
|
|
22
swagger.go
22
swagger.go
|
@ -6,16 +6,15 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
||||||
var apiV1JSON []byte
|
var apiV1JSON []byte
|
||||||
|
|
||||||
func SwaggerUI(c *gin.Context) {
|
func SwaggerUI(ctx *gin.Context) {
|
||||||
t := template.Must(template.New("swagger").Parse(`
|
swaggerTemplate := template.Must(template.New("swagger").Parse(`
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||||
|
@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) {
|
||||||
</html>`))
|
</html>`))
|
||||||
|
|
||||||
var payload bytes.Buffer
|
var payload bytes.Buffer
|
||||||
if err := t.Execute(&payload, struct{}{}); err != nil {
|
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Swagger")
|
Msg("Could not render Swagger")
|
||||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"))
|
ctx.Data(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"text/html; charset=utf-8",
|
||||||
|
[]byte("Could not render Swagger"),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func SwaggerAPIv1(c *gin.Context) {
|
func SwaggerAPIv1(ctx *gin.Context) {
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||||
}
|
}
|
||||||
|
|
57
utils.go
57
utils.go
|
@ -20,31 +20,48 @@ import (
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
errCannotDecryptReponse = Error("cannot decrypt response")
|
||||||
|
errResponseMissingNonce = Error("response missing nonce")
|
||||||
|
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
||||||
|
)
|
||||||
|
|
||||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
||||||
type Error string
|
type Error string
|
||||||
|
|
||||||
func (e Error) Error() string { return string(e) }
|
func (e Error) Error() string { return string(e) }
|
||||||
|
|
||||||
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
|
func decode(
|
||||||
|
msg []byte,
|
||||||
|
v interface{},
|
||||||
|
pubKey *wgkey.Key,
|
||||||
|
privKey *wgkey.Private,
|
||||||
|
) error {
|
||||||
return decodeMsg(msg, v, pubKey, privKey)
|
return decodeMsg(msg, v, pubKey, privKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
|
func decodeMsg(
|
||||||
|
msg []byte,
|
||||||
|
output interface{},
|
||||||
|
pubKey *wgkey.Key,
|
||||||
|
privKey *wgkey.Private,
|
||||||
|
) error {
|
||||||
decrypted, err := decryptMsg(msg, pubKey, privKey)
|
decrypted, err := decryptMsg(msg, pubKey, privKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// fmt.Println(string(decrypted))
|
// fmt.Println(string(decrypted))
|
||||||
if err := json.Unmarshal(decrypted, v); err != nil {
|
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||||
return fmt.Errorf("response: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
||||||
var nonce [24]byte
|
var nonce [24]byte
|
||||||
if len(msg) < len(nonce)+1 {
|
if len(msg) < len(nonce)+1 {
|
||||||
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
|
return nil, errResponseMissingNonce
|
||||||
}
|
}
|
||||||
copy(nonce[:], msg)
|
copy(nonce[:], msg)
|
||||||
msg = msg[len(nonce):]
|
msg = msg[len(nonce):]
|
||||||
|
@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
|
||||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||||
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
|
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("cannot decrypt response")
|
return nil, errCannotDecryptReponse
|
||||||
}
|
}
|
||||||
|
|
||||||
return decrypted, nil
|
return decrypted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
|
||||||
return encodeMsg(b, pubKey, privKey)
|
return encodeMsg(b, pubKey, privKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
func encodeMsg(
|
||||||
|
payload []byte,
|
||||||
|
pubKey *wgkey.Key,
|
||||||
|
privKey *wgkey.Private,
|
||||||
|
) ([]byte, error) {
|
||||||
var nonce [24]byte
|
var nonce [24]byte
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||||
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
|
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if !ipPrefix.Contains(ip) {
|
if !ipPrefix.Contains(ip) {
|
||||||
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
|
return nil, errCouldNotAllocateIP
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
|
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
|
||||||
|
@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
||||||
ipRaw := ip.As4()
|
ipRaw := ip.As4()
|
||||||
if ipRaw[3] == 0 || ipRaw[3] == 255 {
|
if ipRaw[3] == 0 || ipRaw[3] == 255 {
|
||||||
ip = ip.Next()
|
ip = ip.Next()
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.IsZero() &&
|
if ip.IsZero() &&
|
||||||
ip.IsLoopback() {
|
ip.IsLoopback() {
|
||||||
|
|
||||||
ip = ip.Next()
|
ip = ip.Next()
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
|
||||||
if addr != "" {
|
if addr != "" {
|
||||||
ip, err := netaddr.ParseIP(addr)
|
ip, err := netaddr.ParseIP(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
|
return nil, fmt.Errorf("failed to parse ip from database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ips[index] = ip
|
ips[index] = ip
|
||||||
|
@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
||||||
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
|
return fmt.Sprintf(
|
||||||
|
"{ Node: %s, Peers: %s }",
|
||||||
|
resp.Node.Name,
|
||||||
|
tailNodesToString(resp.Peers),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
var d net.Dialer
|
var d net.Dialer
|
||||||
|
|
||||||
return d.DialContext(ctx, "unix", addr)
|
return d.DialContext(ctx, "unix", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
||||||
result := make([]netaddr.IPPrefix, len(prefixes))
|
result := make([]netaddr.IPPrefix, len(prefixes))
|
||||||
|
|
||||||
for index, prefixStr := range prefixes {
|
for index, prefixStr := range prefixes {
|
||||||
|
@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
|
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
|
||||||
for _, p := range prefixes {
|
for _, p := range prefixes {
|
||||||
if prefix == p {
|
if prefix == p {
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := app.getAvailableIP()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
n, err := h.CreateNamespace("test_ip")
|
namespace, err := app.CreateNamespace("test_ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddress: ip.String(),
|
IPAddress: ip.String(),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
ips, err := h.getUsedIPs()
|
ips, err := app.getUsedIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
|
|
||||||
c.Assert(ips[0], check.Equals, expected)
|
c.Assert(ips[0], check.Equals, expected)
|
||||||
|
|
||||||
m1, err := h.GetMachineByID(0)
|
machine1, err := app.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(m1.IPAddress, check.Equals, expected.String())
|
c.Assert(machine1.IPAddress, check.Equals, expected.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test-ip-multi")
|
namespace, err := app.CreateNamespace("test-ip-multi")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
for i := 1; i <= 350; i++ {
|
for index := 1; index <= 350; index++ {
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: uint64(i),
|
ID: uint64(index),
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddress: ip.String(),
|
IPAddress: ip.String(),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, err := h.getUsedIPs()
|
ips, err := app.getUsedIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
|
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
|
||||||
|
|
||||||
// Check that we can read back the IPs
|
// Check that we can read back the IPs
|
||||||
m1, err := h.GetMachineByID(1)
|
machine1, err := app.GetMachineByID(1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
|
c.Assert(
|
||||||
|
machine1.IPAddress,
|
||||||
|
check.Equals,
|
||||||
|
netaddr.MustParseIP("10.27.0.1").String(),
|
||||||
|
)
|
||||||
|
|
||||||
m50, err := h.GetMachineByID(50)
|
machine50, err := app.GetMachineByID(50)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
|
c.Assert(
|
||||||
|
machine50.IPAddress,
|
||||||
|
check.Equals,
|
||||||
|
netaddr.MustParseIP("10.27.0.50").String(),
|
||||||
|
)
|
||||||
|
|
||||||
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
|
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
|
||||||
nextIP, err := h.getAvailableIP()
|
nextIP, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
|
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
|
||||||
|
|
||||||
// If we call get Available again, we should receive
|
// If we call get Available again, we should receive
|
||||||
// the same IP, as it has not been reserved.
|
// the same IP, as it has not been reserved.
|
||||||
nextIP2, err := h.getAvailableIP()
|
nextIP2, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
|
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
expected := netaddr.MustParseIP("10.27.0.1")
|
expected := netaddr.MustParseIP("10.27.0.1")
|
||||||
|
|
||||||
c.Assert(ip.String(), check.Equals, expected.String())
|
c.Assert(ip.String(), check.Equals, expected.String())
|
||||||
|
|
||||||
n, err := h.CreateNamespace("test_ip")
|
namespace, err := app.CreateNamespace("test_ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
m := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "testmachine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: namespace.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
h.db.Save(&m)
|
app.db.Save(&machine)
|
||||||
|
|
||||||
ip2, err := h.getAvailableIP()
|
ip2, err := app.getAvailableIP()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(ip2.String(), check.Equals, expected.String())
|
c.Assert(ip2.String(), check.Equals, expected.String())
|
||||||
|
|
Loading…
Reference in a new issue