mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
introduce mapper package
The mapper package contains functions related to creating and marshalling reponses to machines. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
699655a93f
commit
f7f472ae07
18 changed files with 780 additions and 1063 deletions
47
flake.nix
47
flake.nix
|
@ -6,39 +6,36 @@
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
};
|
};
|
||||||
|
|
||||||
outputs =
|
outputs = {
|
||||||
{ self
|
self,
|
||||||
, nixpkgs
|
nixpkgs,
|
||||||
, flake-utils
|
flake-utils,
|
||||||
, ...
|
...
|
||||||
}:
|
}: let
|
||||||
let
|
|
||||||
headscaleVersion =
|
headscaleVersion =
|
||||||
if (self ? shortRev)
|
if (self ? shortRev)
|
||||||
then self.shortRev
|
then self.shortRev
|
||||||
else "dev";
|
else "dev";
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
overlay = _: prev:
|
overlay = _: prev: let
|
||||||
let
|
|
||||||
pkgs = nixpkgs.legacyPackages.${prev.system};
|
pkgs = nixpkgs.legacyPackages.${prev.system};
|
||||||
in
|
in rec {
|
||||||
rec {
|
|
||||||
headscale = pkgs.buildGo120Module rec {
|
headscale = pkgs.buildGo120Module rec {
|
||||||
pname = "headscale";
|
pname = "headscale";
|
||||||
version = headscaleVersion;
|
version = headscaleVersion;
|
||||||
src = pkgs.lib.cleanSource self;
|
src = pkgs.lib.cleanSource self;
|
||||||
|
|
||||||
tags = [ "ts2019" ];
|
tags = ["ts2019"];
|
||||||
|
|
||||||
# Only run unit tests when testing a build
|
# Only run unit tests when testing a build
|
||||||
checkFlags = [ "-short" ];
|
checkFlags = ["-short"];
|
||||||
|
|
||||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||||
# update this if you have a mismatch after doing a change to thos files.
|
# update this if you have a mismatch after doing a change to thos files.
|
||||||
vendorSha256 = "sha256-IOkbbFtE6+tNKnglE/8ZuNxhPSnloqM2sLgTvagMmnc=";
|
vendorSha256 = "sha256-ui0V7a8bAAm5B7zfN9g2pWTyMpudtm10RgYWQwC6kcA=";
|
||||||
|
|
||||||
ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ];
|
ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
|
||||||
};
|
};
|
||||||
|
|
||||||
golines = pkgs.buildGoModule rec {
|
golines = pkgs.buildGoModule rec {
|
||||||
|
@ -54,7 +51,7 @@
|
||||||
|
|
||||||
vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18=";
|
vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18=";
|
||||||
|
|
||||||
nativeBuildInputs = [ pkgs.installShellFiles ];
|
nativeBuildInputs = [pkgs.installShellFiles];
|
||||||
};
|
};
|
||||||
|
|
||||||
golangci-lint = prev.golangci-lint.override {
|
golangci-lint = prev.golangci-lint.override {
|
||||||
|
@ -76,20 +73,19 @@
|
||||||
|
|
||||||
vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA=";
|
vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA=";
|
||||||
|
|
||||||
nativeBuildInputs = [ pkgs.installShellFiles ];
|
nativeBuildInputs = [pkgs.installShellFiles];
|
||||||
|
|
||||||
subPackages = [ "protoc-gen-grpc-gateway" "protoc-gen-openapiv2" ];
|
subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
// flake-utils.lib.eachDefaultSystem
|
// flake-utils.lib.eachDefaultSystem
|
||||||
(system:
|
(system: let
|
||||||
let
|
|
||||||
pkgs = import nixpkgs {
|
pkgs = import nixpkgs {
|
||||||
overlays = [ self.overlay ];
|
overlays = [self.overlay];
|
||||||
inherit system;
|
inherit system;
|
||||||
};
|
};
|
||||||
buildDeps = with pkgs; [ git go_1_20 gnumake ];
|
buildDeps = with pkgs; [git go_1_20 gnumake];
|
||||||
devDeps = with pkgs;
|
devDeps = with pkgs;
|
||||||
buildDeps
|
buildDeps
|
||||||
++ [
|
++ [
|
||||||
|
@ -123,11 +119,10 @@
|
||||||
headscale-docker = pkgs.dockerTools.buildLayeredImage {
|
headscale-docker = pkgs.dockerTools.buildLayeredImage {
|
||||||
name = "headscale";
|
name = "headscale";
|
||||||
tag = headscaleVersion;
|
tag = headscaleVersion;
|
||||||
contents = [ pkgs.headscale ];
|
contents = [pkgs.headscale];
|
||||||
config.Entrypoint = [ (pkgs.headscale + "/bin/headscale") ];
|
config.Entrypoint = [(pkgs.headscale + "/bin/headscale")];
|
||||||
};
|
};
|
||||||
in
|
in rec {
|
||||||
rec {
|
|
||||||
# `nix develop`
|
# `nix develop`
|
||||||
devShell = pkgs.mkShell {
|
devShell = pkgs.mkShell {
|
||||||
buildInputs = devDeps;
|
buildInputs = devDeps;
|
||||||
|
|
|
@ -1,115 +0,0 @@
|
||||||
package hscontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (h *Headscale) generateMapResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
machine *types.Machine,
|
|
||||||
) (*tailcfg.MapResponse, error) {
|
|
||||||
log.Trace().
|
|
||||||
Str("func", "generateMapResponse").
|
|
||||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
|
||||||
Msg("Creating Map response")
|
|
||||||
node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("func", "generateMapResponse").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot convert to node")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peers, err := h.db.GetValidPeers(h.aclRules, machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("func", "generateMapResponse").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot fetch peers")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
|
|
||||||
|
|
||||||
nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("func", "generateMapResponse").
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to convert peers to Tailscale nodes")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
|
||||||
h.cfg.DNSConfig,
|
|
||||||
h.cfg.BaseDomain,
|
|
||||||
*machine,
|
|
||||||
peers,
|
|
||||||
)
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
|
||||||
KeepAlive: false,
|
|
||||||
Node: node,
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
DERPMap: h.DERPMap,
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
Peers: nodePeers,
|
|
||||||
|
|
||||||
// TODO(kradalby): Implement:
|
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
|
||||||
// PeersChanged
|
|
||||||
// PeersRemoved
|
|
||||||
// PeersChangedPatch
|
|
||||||
// PeerSeenChange
|
|
||||||
// OnlineChange
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
DNSConfig: dnsConfig,
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
Domain: h.cfg.BaseDomain,
|
|
||||||
|
|
||||||
// Do not instruct clients to collect services, we do not
|
|
||||||
// support or do anything with them
|
|
||||||
CollectServices: "false",
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
PacketFilter: h.aclRules,
|
|
||||||
|
|
||||||
UserProfiles: profiles,
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
SSHPolicy: h.sshPolicy,
|
|
||||||
|
|
||||||
ControlTime: &now,
|
|
||||||
|
|
||||||
Debug: &tailcfg.Debug{
|
|
||||||
DisableLogTail: !h.cfg.LogTail.Enabled,
|
|
||||||
RandomizeClientPort: h.cfg.RandomizeClientPort,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("func", "generateMapResponse").
|
|
||||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
|
||||||
// Interface("payload", resp).
|
|
||||||
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
|
|
||||||
|
|
||||||
return &resp, nil
|
|
||||||
}
|
|
|
@ -80,7 +80,7 @@ type Headscale struct {
|
||||||
dbString string
|
dbString string
|
||||||
dbType string
|
dbType string
|
||||||
dbDebug bool
|
dbDebug bool
|
||||||
privateKey *key.MachinePrivate
|
privateKey2019 *key.MachinePrivate
|
||||||
noisePrivateKey *key.MachinePrivate
|
noisePrivateKey *key.MachinePrivate
|
||||||
|
|
||||||
DERPMap *tailcfg.DERPMap
|
DERPMap *tailcfg.DERPMap
|
||||||
|
@ -166,7 +166,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: cfg.DBtype,
|
dbType: cfg.DBtype,
|
||||||
dbString: dbString,
|
dbString: dbString,
|
||||||
privateKey: privateKey,
|
privateKey2019: privateKey,
|
||||||
noisePrivateKey: noisePrivateKey,
|
noisePrivateKey: noisePrivateKey,
|
||||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
|
|
|
@ -39,15 +39,7 @@ var (
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
// ListPeers returns all peers of machine, regardless of any Policy.
|
||||||
// related to the application outside of tests.
|
|
||||||
func (hsdb *HSDatabase) filterMachinesByACL(
|
|
||||||
aclRules []tailcfg.FilterRule,
|
|
||||||
currentMachine *types.Machine, peers types.Machines,
|
|
||||||
) types.Machines {
|
|
||||||
return policy.FilterMachinesByACL(currentMachine, peers, aclRules)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
|
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -72,67 +64,6 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getPeers(
|
|
||||||
aclRules []tailcfg.FilterRule,
|
|
||||||
machine *types.Machine,
|
|
||||||
) (types.Machines, error) {
|
|
||||||
var peers types.Machines
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// If ACLs rules are defined, filter visible host list with the ACLs
|
|
||||||
// else use the classic user scope
|
|
||||||
if len(aclRules) > 0 {
|
|
||||||
var machines []types.Machine
|
|
||||||
machines, err = hsdb.ListMachines()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Error retrieving list of machines")
|
|
||||||
|
|
||||||
return types.Machines{}, err
|
|
||||||
}
|
|
||||||
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
|
|
||||||
} else {
|
|
||||||
peers, err = hsdb.ListPeers(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot fetch peers")
|
|
||||||
|
|
||||||
return types.Machines{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("self", machine.Hostname).
|
|
||||||
Str("peers", peers.String()).
|
|
||||||
Msg("Peers returned to caller")
|
|
||||||
|
|
||||||
return peers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetValidPeers(
|
|
||||||
aclRules []tailcfg.FilterRule,
|
|
||||||
machine *types.Machine,
|
|
||||||
) (types.Machines, error) {
|
|
||||||
validPeers := make(types.Machines, 0)
|
|
||||||
|
|
||||||
peers, err := hsdb.getPeers(aclRules, machine)
|
|
||||||
if err != nil {
|
|
||||||
return types.Machines{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range peers {
|
|
||||||
if !peer.IsExpired() {
|
|
||||||
validPeers = append(validPeers, peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return validPeers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) {
|
func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) {
|
||||||
machines := []types.Machine{}
|
machines := []types.Machine{}
|
||||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||||
|
|
|
@ -293,8 +293,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
|
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules)
|
||||||
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
|
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules)
|
||||||
|
|
||||||
c.Log(peersOfTestMachine)
|
c.Log(peersOfTestMachine)
|
||||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||||
|
|
|
@ -2,13 +2,11 @@ package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -163,32 +161,3 @@ func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
|
|
||||||
machine types.Machine,
|
|
||||||
peers types.Machines,
|
|
||||||
) []tailcfg.UserProfile {
|
|
||||||
userMap := make(map[string]types.User)
|
|
||||||
userMap[machine.User.Name] = machine.User
|
|
||||||
for _, peer := range peers {
|
|
||||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := []tailcfg.UserProfile{}
|
|
||||||
for _, user := range userMap {
|
|
||||||
displayName := user.Name
|
|
||||||
|
|
||||||
if hsdb.baseDomain != "" {
|
|
||||||
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles = append(profiles,
|
|
||||||
tailcfg.UserProfile{
|
|
||||||
ID: tailcfg.UserID(user.ID),
|
|
||||||
LoginName: user.Name,
|
|
||||||
DisplayName: displayName,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return profiles
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
|
@ -94,151 +91,6 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
c.Assert(err, check.Equals, ErrUserExists)
|
c.Assert(err, check.Equals, ErrUserExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
|
||||||
userShared1, err := db.CreateUser("shared1")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared2, err := db.CreateUser("shared2")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared3, err := db.CreateUser("shared3")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared1, err := db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared2, err := db.CreatePreAuthKey(
|
|
||||||
userShared2.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared3, err := db.CreatePreAuthKey(
|
|
||||||
userShared3.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKey2Shared1, err := db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
|
|
||||||
machineInShared1 := &types.Machine{
|
|
||||||
ID: 1,
|
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
Hostname: "test_get_shared_nodes_1",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
|
||||||
}
|
|
||||||
db.db.Save(machineInShared1)
|
|
||||||
|
|
||||||
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared2 := &types.Machine{
|
|
||||||
ID: 2,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_2",
|
|
||||||
UserID: userShared2.ID,
|
|
||||||
User: *userShared2,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
|
||||||
}
|
|
||||||
db.db.Save(machineInShared2)
|
|
||||||
|
|
||||||
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared3 := &types.Machine{
|
|
||||||
ID: 3,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_3",
|
|
||||||
UserID: userShared3.ID,
|
|
||||||
User: *userShared3,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
|
||||||
}
|
|
||||||
db.db.Save(machineInShared3)
|
|
||||||
|
|
||||||
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine2InShared1 := &types.Machine{
|
|
||||||
ID: 4,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_4",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
|
||||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
|
||||||
}
|
|
||||||
db.db.Save(machine2InShared1)
|
|
||||||
|
|
||||||
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userProfiles := db.GetMapResponseUserProfiles(
|
|
||||||
*machineInShared1,
|
|
||||||
peersOfMachine1InShared1,
|
|
||||||
)
|
|
||||||
|
|
||||||
c.Assert(len(userProfiles), check.Equals, 3)
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, userProfiles := range userProfiles {
|
|
||||||
if userProfiles.DisplayName == userShared1.Name {
|
|
||||||
found = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Assert(found, check.Equals, true)
|
|
||||||
|
|
||||||
found = false
|
|
||||||
for _, userProfile := range userProfiles {
|
|
||||||
if userProfile.DisplayName == userShared2.Name {
|
|
||||||
found = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Assert(found, check.Equals, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
oldUser, err := db.CreateUser("old")
|
oldUser, err := db.CreateUser("old")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
|
@ -32,7 +32,7 @@ type DERPServer struct {
|
||||||
|
|
||||||
func (h *Headscale) NewDERPServer() (*DERPServer, error) {
|
func (h *Headscale) NewDERPServer() (*DERPServer, error) {
|
||||||
log.Trace().Caller().Msg("Creating new embedded DERP server")
|
log.Trace().Caller().Msg("Creating new embedded DERP server")
|
||||||
server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf)
|
server := derp.NewServer(key.NodePrivate(*h.privateKey2019), log.Info().Msgf)
|
||||||
region, err := h.generateRegionLocalDERP()
|
region, err := h.generateRegionLocalDERP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -156,7 +156,7 @@ func (h *Headscale) DERPHandler(
|
||||||
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
|
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
|
||||||
|
|
||||||
if !fastStart {
|
if !fastStart {
|
||||||
pubKey := h.privateKey.Public()
|
pubKey := h.privateKey2019.Public()
|
||||||
pubKeyStr, _ := pubKey.MarshalText() //nolint
|
pubKeyStr, _ := pubKey.MarshalText() //nolint
|
||||||
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
|
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
|
||||||
"Upgrade: DERP\r\n"+
|
"Upgrade: DERP\r\n"+
|
||||||
|
|
|
@ -3,14 +3,9 @@ package hscontrol
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
mapset "github.com/deckarep/golang-set/v2"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/types/dnstype"
|
|
||||||
"tailscale.com/util/dnsname"
|
"tailscale.com/util/dnsname"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,10 +18,6 @@ const (
|
||||||
ipv6AddressLength = 128
|
ipv6AddressLength = 128
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
||||||
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
||||||
// server (listening in 100.100.100.100 udp/53) should be used for.
|
// server (listening in 100.100.100.100 udp/53) should be used for.
|
||||||
|
@ -158,63 +149,3 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||||
|
|
||||||
return fqdns
|
return fqdns
|
||||||
}
|
}
|
||||||
|
|
||||||
// If any nextdns DoH resolvers are present in the list of resolvers it will
|
|
||||||
// take metadata from the machine metadata and instruct tailscale to add it
|
|
||||||
// to the requests. This makes it possible to identify from which device the
|
|
||||||
// requests come in the NextDNS dashboard.
|
|
||||||
//
|
|
||||||
// This will produce a resolver like:
|
|
||||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
|
||||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
|
||||||
for _, resolver := range resolvers {
|
|
||||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
|
||||||
attrs := url.Values{
|
|
||||||
"device_name": []string{machine.Hostname},
|
|
||||||
"device_model": []string{machine.HostInfo.OS},
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(machine.IPAddresses) > 0 {
|
|
||||||
attrs.Add("device_ip", machine.IPAddresses[0].String())
|
|
||||||
}
|
|
||||||
|
|
||||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMapResponseDNSConfig(
|
|
||||||
dnsConfigOrig *tailcfg.DNSConfig,
|
|
||||||
baseDomain string,
|
|
||||||
machine types.Machine,
|
|
||||||
peers types.Machines,
|
|
||||||
) *tailcfg.DNSConfig {
|
|
||||||
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
|
|
||||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
|
||||||
// Only inject the Search Domain of the current user - shared nodes should use their full FQDN
|
|
||||||
dnsConfig.Domains = append(
|
|
||||||
dnsConfig.Domains,
|
|
||||||
fmt.Sprintf(
|
|
||||||
"%s.%s",
|
|
||||||
machine.User.Name,
|
|
||||||
baseDomain,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
userSet := mapset.NewSet[types.User]()
|
|
||||||
userSet.Add(machine.User)
|
|
||||||
for _, p := range peers {
|
|
||||||
userSet.Add(p.User)
|
|
||||||
}
|
|
||||||
for _, user := range userSet.ToSlice() {
|
|
||||||
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
|
||||||
dnsConfig.Routes[dnsRoute] = nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dnsConfig = dnsConfigOrig
|
|
||||||
}
|
|
||||||
|
|
||||||
addNextDNSMetadata(dnsConfig.Resolvers, machine)
|
|
||||||
|
|
||||||
return dnsConfig
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,14 +1,9 @@
|
||||||
package hscontrol
|
package hscontrol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/types/dnstype"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||||
|
@ -112,293 +107,3 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
|
||||||
c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
||||||
c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
|
||||||
userShared1, err := app.db.CreateUser("shared1")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared2, err := app.db.CreateUser("shared2")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared3, err := app.db.CreateUser("shared3")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared2.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared3.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
PreAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
|
|
||||||
machineInShared1 := &types.Machine{
|
|
||||||
ID: 1,
|
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
Hostname: "test_get_shared_nodes_1",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared2 := &types.Machine{
|
|
||||||
ID: 2,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_2",
|
|
||||||
UserID: userShared2.ID,
|
|
||||||
User: *userShared2,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared2)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared3 := &types.Machine{
|
|
||||||
ID: 3,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_3",
|
|
||||||
UserID: userShared3.ID,
|
|
||||||
User: *userShared3,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared3)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine2InShared1 := &types.Machine{
|
|
||||||
ID: 4,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_4",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
|
||||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machine2InShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
|
||||||
Routes: make(map[string][]*dnstype.Resolver),
|
|
||||||
Domains: []string{baseDomain},
|
|
||||||
Proxied: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
|
||||||
&dnsConfigOrig,
|
|
||||||
baseDomain,
|
|
||||||
*machineInShared1,
|
|
||||||
peersOfMachineInShared1,
|
|
||||||
)
|
|
||||||
c.Assert(dnsConfig, check.NotNil)
|
|
||||||
|
|
||||||
c.Assert(len(dnsConfig.Routes), check.Equals, 3)
|
|
||||||
|
|
||||||
domainRouteShared1 := fmt.Sprintf("%s.%s", userShared1.Name, baseDomain)
|
|
||||||
_, ok := dnsConfig.Routes[domainRouteShared1]
|
|
||||||
c.Assert(ok, check.Equals, true)
|
|
||||||
|
|
||||||
domainRouteShared2 := fmt.Sprintf("%s.%s", userShared2.Name, baseDomain)
|
|
||||||
_, ok = dnsConfig.Routes[domainRouteShared2]
|
|
||||||
c.Assert(ok, check.Equals, true)
|
|
||||||
|
|
||||||
domainRouteShared3 := fmt.Sprintf("%s.%s", userShared3.Name, baseDomain)
|
|
||||||
_, ok = dnsConfig.Routes[domainRouteShared3]
|
|
||||||
c.Assert(ok, check.Equals, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
|
||||||
userShared1, err := app.db.CreateUser("shared1")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared2, err := app.db.CreateUser("shared2")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared3, err := app.db.CreateUser("shared3")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared2.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared3.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
|
|
||||||
machineInShared1 := &types.Machine{
|
|
||||||
ID: 1,
|
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
Hostname: "test_get_shared_nodes_1",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared2 := &types.Machine{
|
|
||||||
ID: 2,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_2",
|
|
||||||
UserID: userShared2.ID,
|
|
||||||
User: *userShared2,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared2)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared3 := &types.Machine{
|
|
||||||
ID: 3,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_3",
|
|
||||||
UserID: userShared3.ID,
|
|
||||||
User: *userShared3,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machineInShared3)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine2InShared1 := &types.Machine{
|
|
||||||
ID: 4,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_4",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
|
||||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
|
||||||
}
|
|
||||||
err = app.db.MachineSave(machine2InShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
|
||||||
Routes: make(map[string][]*dnstype.Resolver),
|
|
||||||
Domains: []string{baseDomain},
|
|
||||||
Proxied: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
|
||||||
&dnsConfigOrig,
|
|
||||||
baseDomain,
|
|
||||||
*machineInShared1,
|
|
||||||
peersOfMachine1Shared1,
|
|
||||||
)
|
|
||||||
c.Assert(dnsConfig, check.NotNil)
|
|
||||||
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
|
|
||||||
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
|
|
||||||
}
|
|
||||||
|
|
418
hscontrol/mapper/mapper.go
Normal file
418
hscontrol/mapper/mapper.go
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mapset "github.com/deckarep/golang-set/v2"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"tailscale.com/smallzstd"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/dnstype"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||||
|
reservedResponseHeaderSize = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mapper struct {
|
||||||
|
db *db.HSDatabase
|
||||||
|
|
||||||
|
privateKey2019 *key.MachinePrivate
|
||||||
|
isNoise bool
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
// TODO(kradalby): figure out if this is the format we want this in
|
||||||
|
derpMap *tailcfg.DERPMap
|
||||||
|
baseDomain string
|
||||||
|
dnsCfg *tailcfg.DNSConfig
|
||||||
|
logtail bool
|
||||||
|
randomClientPort bool
|
||||||
|
stripEmailDomain bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMapper(
|
||||||
|
db *db.HSDatabase,
|
||||||
|
privateKey *key.MachinePrivate,
|
||||||
|
isNoise bool,
|
||||||
|
derpMap *tailcfg.DERPMap,
|
||||||
|
baseDomain string,
|
||||||
|
dnsCfg *tailcfg.DNSConfig,
|
||||||
|
logtail bool,
|
||||||
|
randomClientPort bool,
|
||||||
|
stripEmailDomain bool,
|
||||||
|
) *Mapper {
|
||||||
|
return &Mapper{
|
||||||
|
db: db,
|
||||||
|
|
||||||
|
privateKey2019: privateKey,
|
||||||
|
isNoise: isNoise,
|
||||||
|
|
||||||
|
derpMap: derpMap,
|
||||||
|
baseDomain: baseDomain,
|
||||||
|
dnsCfg: dnsCfg,
|
||||||
|
logtail: logtail,
|
||||||
|
randomClientPort: randomClientPort,
|
||||||
|
stripEmailDomain: stripEmailDomain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) fullMapResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
) (*tailcfg.MapResponse, error) {
|
||||||
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||||
|
Msg("Creating Map response")
|
||||||
|
|
||||||
|
// TODO(kradalby): Decouple this from DB?
|
||||||
|
node, err := m.db.TailNode(*machine, pol, m.dnsCfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot convert to node")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peers, err := m.db.ListPeers(machine)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot fetch peers")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, sshPolicy, err := policy.GenerateFilterRules(pol, peers, m.stripEmailDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rules) > 0 {
|
||||||
|
peers = policy.FilterMachinesByACL(machine, peers, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := generateUserProfiles(machine, peers, m.baseDomain)
|
||||||
|
|
||||||
|
// TODO(kradalby): Decouple this from DB?
|
||||||
|
nodePeers, err := m.db.TailNodes(peers, pol, m.dnsCfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to convert peers to Tailscale nodes")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Shold this mutation happen before TailNode(s) is called?
|
||||||
|
dnsConfig := generateDNSConfig(
|
||||||
|
m.dnsCfg,
|
||||||
|
m.baseDomain,
|
||||||
|
*machine,
|
||||||
|
peers,
|
||||||
|
)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
resp := tailcfg.MapResponse{
|
||||||
|
KeepAlive: false,
|
||||||
|
Node: node,
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
DERPMap: m.derpMap,
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
Peers: nodePeers,
|
||||||
|
|
||||||
|
// TODO(kradalby): Implement:
|
||||||
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
||||||
|
// PeersChanged
|
||||||
|
// PeersRemoved
|
||||||
|
// PeersChangedPatch
|
||||||
|
// PeerSeenChange
|
||||||
|
// OnlineChange
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
DNSConfig: dnsConfig,
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
Domain: m.baseDomain,
|
||||||
|
|
||||||
|
// Do not instruct clients to collect services, we do not
|
||||||
|
// support or do anything with them
|
||||||
|
CollectServices: "false",
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
PacketFilter: rules,
|
||||||
|
|
||||||
|
UserProfiles: profiles,
|
||||||
|
|
||||||
|
// TODO: Only send if updated
|
||||||
|
SSHPolicy: sshPolicy,
|
||||||
|
|
||||||
|
ControlTime: &now,
|
||||||
|
|
||||||
|
Debug: &tailcfg.Debug{
|
||||||
|
DisableLogTail: !m.logtail,
|
||||||
|
RandomizeClientPort: m.randomClientPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||||
|
// Interface("payload", resp).
|
||||||
|
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
|
||||||
|
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateUserProfiles(
|
||||||
|
machine *types.Machine,
|
||||||
|
peers types.Machines,
|
||||||
|
baseDomain string,
|
||||||
|
) []tailcfg.UserProfile {
|
||||||
|
userMap := make(map[string]types.User)
|
||||||
|
userMap[machine.User.Name] = machine.User
|
||||||
|
for _, peer := range peers {
|
||||||
|
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := []tailcfg.UserProfile{}
|
||||||
|
for _, user := range userMap {
|
||||||
|
displayName := user.Name
|
||||||
|
|
||||||
|
if baseDomain != "" {
|
||||||
|
displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles = append(profiles,
|
||||||
|
tailcfg.UserProfile{
|
||||||
|
ID: tailcfg.UserID(user.ID),
|
||||||
|
LoginName: user.Name,
|
||||||
|
DisplayName: displayName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateDNSConfig(
|
||||||
|
base *tailcfg.DNSConfig,
|
||||||
|
baseDomain string,
|
||||||
|
machine types.Machine,
|
||||||
|
peers types.Machines,
|
||||||
|
) *tailcfg.DNSConfig {
|
||||||
|
dnsConfig := base.Clone()
|
||||||
|
|
||||||
|
// if MagicDNS is enabled
|
||||||
|
if base != nil && base.Proxied {
|
||||||
|
// Only inject the Search Domain of the current user
|
||||||
|
// shared nodes should use their full FQDN
|
||||||
|
dnsConfig.Domains = append(
|
||||||
|
dnsConfig.Domains,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"%s.%s",
|
||||||
|
machine.User.Name,
|
||||||
|
baseDomain,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
userSet := mapset.NewSet[types.User]()
|
||||||
|
userSet.Add(machine.User)
|
||||||
|
for _, p := range peers {
|
||||||
|
userSet.Add(p.User)
|
||||||
|
}
|
||||||
|
for _, user := range userSet.ToSlice() {
|
||||||
|
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
||||||
|
dnsConfig.Routes[dnsRoute] = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dnsConfig = base
|
||||||
|
}
|
||||||
|
|
||||||
|
addNextDNSMetadata(dnsConfig.Resolvers, machine)
|
||||||
|
|
||||||
|
return dnsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any nextdns DoH resolvers are present in the list of resolvers it will
|
||||||
|
// take metadata from the machine metadata and instruct tailscale to add it
|
||||||
|
// to the requests. This makes it possible to identify from which device the
|
||||||
|
// requests come in the NextDNS dashboard.
|
||||||
|
//
|
||||||
|
// This will produce a resolver like:
|
||||||
|
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||||
|
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||||
|
for _, resolver := range resolvers {
|
||||||
|
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||||
|
attrs := url.Values{
|
||||||
|
"device_name": []string{machine.Hostname},
|
||||||
|
"device_model": []string{machine.HostInfo.OS},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(machine.IPAddresses) > 0 {
|
||||||
|
attrs.Add("device_ip", machine.IPAddresses[0].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) CreateMapResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
) ([]byte, error) {
|
||||||
|
mapResponse, err := m.fullMapResponse(mapRequest, machine, pol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.isNoise {
|
||||||
|
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
var machineKey key.MachinePublic
|
||||||
|
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot parse client key")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) CreateKeepAliveResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
) ([]byte, error) {
|
||||||
|
keepAliveResponse := tailcfg.MapResponse{
|
||||||
|
KeepAlive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.isNoise {
|
||||||
|
return m.marshalMapResponse(
|
||||||
|
keepAliveResponse,
|
||||||
|
key.MachinePublic{},
|
||||||
|
mapRequest.Compress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var machineKey key.MachinePublic
|
||||||
|
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot parse client key")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalResponse(
|
||||||
|
resp interface{},
|
||||||
|
privateKey2019 *key.MachinePrivate,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) ([]byte, error) {
|
||||||
|
jsonBody, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot marshal response")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if privateKey2019 != nil {
|
||||||
|
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) marshalMapResponse(
|
||||||
|
resp interface{},
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
compression string,
|
||||||
|
) ([]byte, error) {
|
||||||
|
jsonBody, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot marshal map response")
|
||||||
|
}
|
||||||
|
|
||||||
|
var respBody []byte
|
||||||
|
if compression == util.ZstdCompression {
|
||||||
|
respBody = zstdEncode(jsonBody)
|
||||||
|
if !m.isNoise { // if legacy protocol
|
||||||
|
respBody = m.privateKey2019.SealTo(machineKey, respBody)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !m.isNoise { // if legacy protocol
|
||||||
|
respBody = m.privateKey2019.SealTo(machineKey, jsonBody)
|
||||||
|
} else {
|
||||||
|
respBody = jsonBody
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, reservedResponseHeaderSize)
|
||||||
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||||
|
data = append(data, respBody...)
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func zstdEncode(in []byte) []byte {
|
||||||
|
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||||
|
if !ok {
|
||||||
|
panic("invalid type in sync pool")
|
||||||
|
}
|
||||||
|
out := encoder.EncodeAll(in, nil)
|
||||||
|
_ = encoder.Close()
|
||||||
|
zstdEncoderPool.Put(encoder)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
var zstdEncoderPool = &sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
encoder, err := smallzstd.NewEncoder(
|
||||||
|
nil,
|
||||||
|
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
},
|
||||||
|
}
|
131
hscontrol/mapper/mapper_test.go
Normal file
131
hscontrol/mapper/mapper_test.go
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/dnstype"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
|
mach := func(hostname, username string, userid uint) types.Machine {
|
||||||
|
return types.Machine{
|
||||||
|
Hostname: hostname,
|
||||||
|
UserID: userid,
|
||||||
|
User: types.User{
|
||||||
|
Name: username,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
machineInShared1 := mach("test_get_shared_nodes_1", "user1", 1)
|
||||||
|
machineInShared2 := mach("test_get_shared_nodes_2", "user2", 2)
|
||||||
|
machineInShared3 := mach("test_get_shared_nodes_3", "user3", 3)
|
||||||
|
machine2InShared1 := mach("test_get_shared_nodes_4", "user1", 1)
|
||||||
|
|
||||||
|
userProfiles := generateUserProfiles(
|
||||||
|
&machineInShared1,
|
||||||
|
types.Machines{
|
||||||
|
machineInShared2, machineInShared3, machine2InShared1,
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
c.Assert(len(userProfiles), check.Equals, 3)
|
||||||
|
|
||||||
|
users := []string{
|
||||||
|
"user1", "user2", "user3",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
found := false
|
||||||
|
for _, userProfile := range userProfiles {
|
||||||
|
if userProfile.DisplayName == user {
|
||||||
|
found = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Assert(found, check.Equals, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSConfigMapResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
magicDNS bool
|
||||||
|
want *tailcfg.DNSConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
magicDNS: true,
|
||||||
|
want: &tailcfg.DNSConfig{
|
||||||
|
Routes: map[string][]*dnstype.Resolver{
|
||||||
|
"shared1.foobar.headscale.net": {},
|
||||||
|
"shared2.foobar.headscale.net": {},
|
||||||
|
"shared3.foobar.headscale.net": {},
|
||||||
|
},
|
||||||
|
Domains: []string{
|
||||||
|
"foobar.headscale.net",
|
||||||
|
"shared1.foobar.headscale.net",
|
||||||
|
},
|
||||||
|
Proxied: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
magicDNS: false,
|
||||||
|
want: &tailcfg.DNSConfig{
|
||||||
|
Domains: []string{"foobar.headscale.net"},
|
||||||
|
Proxied: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) {
|
||||||
|
mach := func(hostname, username string, userid uint) types.Machine {
|
||||||
|
return types.Machine{
|
||||||
|
Hostname: hostname,
|
||||||
|
UserID: userid,
|
||||||
|
User: types.User{
|
||||||
|
Name: username,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
baseDomain := "foobar.headscale.net"
|
||||||
|
|
||||||
|
dnsConfigOrig := tailcfg.DNSConfig{
|
||||||
|
Routes: make(map[string][]*dnstype.Resolver),
|
||||||
|
Domains: []string{baseDomain},
|
||||||
|
Proxied: tt.magicDNS,
|
||||||
|
}
|
||||||
|
|
||||||
|
machineInShared1 := mach("test_get_shared_nodes_1", "shared1", 1)
|
||||||
|
machineInShared2 := mach("test_get_shared_nodes_2", "shared2", 2)
|
||||||
|
machineInShared3 := mach("test_get_shared_nodes_3", "shared3", 3)
|
||||||
|
machine2InShared1 := mach("test_get_shared_nodes_4", "shared1", 1)
|
||||||
|
|
||||||
|
peersOfMachineInShared1 := types.Machines{
|
||||||
|
machineInShared1,
|
||||||
|
machineInShared2,
|
||||||
|
machineInShared3,
|
||||||
|
machine2InShared1,
|
||||||
|
}
|
||||||
|
|
||||||
|
got := generateDNSConfig(
|
||||||
|
&dnsConfigOrig,
|
||||||
|
baseDomain,
|
||||||
|
machineInShared1,
|
||||||
|
peersOfMachineInShared1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||||
|
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
15
hscontrol/mapper/suite_test.go
Normal file
15
hscontrol/mapper/suite_test.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
check.TestingT(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = check.Suite(&Suite{})
|
||||||
|
|
||||||
|
type Suite struct{}
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -61,7 +62,7 @@ func (h *Headscale) KeyHandler(
|
||||||
// TS2021 (Tailscale v2 protocol) requires to have a different key
|
// TS2021 (Tailscale v2 protocol) requires to have a different key
|
||||||
if clientCapabilityVersion >= NoiseCapabilityVersion {
|
if clientCapabilityVersion >= NoiseCapabilityVersion {
|
||||||
resp := tailcfg.OverTLSPublicKeyResponse{
|
resp := tailcfg.OverTLSPublicKeyResponse{
|
||||||
LegacyPublicKey: h.privateKey.Public(),
|
LegacyPublicKey: h.privateKey2019.Public(),
|
||||||
PublicKey: h.noisePrivateKey.Public(),
|
PublicKey: h.noisePrivateKey.Public(),
|
||||||
}
|
}
|
||||||
writer.Header().Set("Content-Type", "application/json")
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
|
@ -84,7 +85,7 @@ func (h *Headscale) KeyHandler(
|
||||||
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -323,7 +324,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -483,7 +484,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
||||||
resp.Login = *pak.User.TailscaleLogin()
|
resp.Login = *pak.User.TailscaleLogin()
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -548,7 +549,7 @@ func (h *Headscale) handleNewMachineCommon(
|
||||||
registerRequest.NodeKey)
|
registerRequest.NodeKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -609,7 +610,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
resp.NodeKeyExpired = true
|
resp.NodeKeyExpired = true
|
||||||
resp.User = *machine.User.TailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -673,7 +674,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||||
resp.User = *machine.User.TailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
resp.Login = *machine.User.TailscaleLogin()
|
resp.Login = *machine.User.TailscaleLogin()
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -735,7 +736,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.User = *machine.User.TailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -802,7 +803,7 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
|
||||||
registerRequest.NodeKey)
|
registerRequest.NodeKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -29,6 +30,19 @@ func (h *Headscale) handlePollCommon(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||||
|
// per client or something similar
|
||||||
|
mapp := mapper.NewMapper(h.db,
|
||||||
|
h.privateKey2019,
|
||||||
|
isNoise,
|
||||||
|
h.DERPMap,
|
||||||
|
h.cfg.BaseDomain,
|
||||||
|
h.cfg.DNSConfig,
|
||||||
|
h.cfg.LogTail.Enabled,
|
||||||
|
h.cfg.RandomizeClientPort,
|
||||||
|
h.cfg.OIDC.StripEmaildomain,
|
||||||
|
)
|
||||||
|
|
||||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||||
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
||||||
|
@ -87,7 +101,7 @@ func (h *Headscale) handlePollCommon(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -245,6 +259,19 @@ func (h *Headscale) pollNetMapStream(
|
||||||
updateChan chan struct{},
|
updateChan chan struct{},
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||||
|
// per client or something similar
|
||||||
|
mapp := mapper.NewMapper(h.db,
|
||||||
|
h.privateKey2019,
|
||||||
|
isNoise,
|
||||||
|
h.DERPMap,
|
||||||
|
h.cfg.BaseDomain,
|
||||||
|
h.cfg.DNSConfig,
|
||||||
|
h.cfg.LogTail.Enabled,
|
||||||
|
h.cfg.RandomizeClientPort,
|
||||||
|
h.cfg.OIDC.StripEmaildomain,
|
||||||
|
)
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Add(1)
|
h.pollNetMapStreamWG.Add(1)
|
||||||
defer h.pollNetMapStreamWG.Done()
|
defer h.pollNetMapStreamWG.Done()
|
||||||
|
|
||||||
|
@ -463,7 +490,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
Time("last_successful_update", lastUpdate).
|
Time("last_successful_update", lastUpdate).
|
||||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
Time("last_state_change", h.getLastStateChange(machine.User)).
|
||||||
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
|
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
|
||||||
data, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -623,6 +650,19 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||||
|
// per client or something similar
|
||||||
|
mapp := mapper.NewMapper(h.db,
|
||||||
|
h.privateKey2019,
|
||||||
|
isNoise,
|
||||||
|
h.DERPMap,
|
||||||
|
h.cfg.BaseDomain,
|
||||||
|
h.cfg.DNSConfig,
|
||||||
|
h.cfg.LogTail.Enabled,
|
||||||
|
h.cfg.RandomizeClientPort,
|
||||||
|
h.cfg.OIDC.StripEmaildomain,
|
||||||
|
)
|
||||||
|
|
||||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
|
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
|
||||||
|
|
||||||
|
@ -643,7 +683,7 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-keepAliveTicker.C:
|
case <-keepAliveTicker.C:
|
||||||
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise)
|
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "keepAlive").
|
Str("func", "keepAlive").
|
||||||
|
|
|
@ -1,157 +1 @@
|
||||||
package hscontrol
|
package hscontrol
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/klauspost/compress/zstd"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"tailscale.com/smallzstd"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (h *Headscale) getMapResponseData(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
machine *types.Machine,
|
|
||||||
isNoise bool,
|
|
||||||
) ([]byte, error) {
|
|
||||||
mapResponse, err := h.generateMapResponse(mapRequest, machine)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNoise {
|
|
||||||
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
|
|
||||||
}
|
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
|
||||||
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot parse client key")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getMapKeepAliveResponseData(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
machine *types.Machine,
|
|
||||||
isNoise bool,
|
|
||||||
) ([]byte, error) {
|
|
||||||
keepAliveResponse := tailcfg.MapResponse{
|
|
||||||
KeepAlive: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNoise {
|
|
||||||
return h.marshalMapResponse(
|
|
||||||
keepAliveResponse,
|
|
||||||
key.MachinePublic{},
|
|
||||||
mapRequest.Compress,
|
|
||||||
isNoise,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
|
||||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot parse client key")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) marshalResponse(
|
|
||||||
resp interface{},
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
isNoise bool,
|
|
||||||
) ([]byte, error) {
|
|
||||||
jsonBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot marshal response")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNoise {
|
|
||||||
return jsonBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.privateKey.SealTo(machineKey, jsonBody), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) marshalMapResponse(
|
|
||||||
resp interface{},
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
compression string,
|
|
||||||
isNoise bool,
|
|
||||||
) ([]byte, error) {
|
|
||||||
jsonBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot marshal map response")
|
|
||||||
}
|
|
||||||
|
|
||||||
var respBody []byte
|
|
||||||
if compression == util.ZstdCompression {
|
|
||||||
respBody = zstdEncode(jsonBody)
|
|
||||||
if !isNoise { // if legacy protocol
|
|
||||||
respBody = h.privateKey.SealTo(machineKey, respBody)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !isNoise { // if legacy protocol
|
|
||||||
respBody = h.privateKey.SealTo(machineKey, jsonBody)
|
|
||||||
} else {
|
|
||||||
respBody = jsonBody
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data := make([]byte, reservedResponseHeaderSize)
|
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
|
||||||
data = append(data, respBody...)
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func zstdEncode(in []byte) []byte {
|
|
||||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
|
||||||
if !ok {
|
|
||||||
panic("invalid type in sync pool")
|
|
||||||
}
|
|
||||||
out := encoder.EncodeAll(in, nil)
|
|
||||||
_ = encoder.Close()
|
|
||||||
zstdEncoderPool.Put(encoder)
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
var zstdEncoderPool = &sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
encoder, err := smallzstd.NewEncoder(
|
|
||||||
nil,
|
|
||||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encoder
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ func (h *Headscale) RegistrationHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
registerRequest := tailcfg.RegisterRequest{}
|
registerRequest := tailcfg.RegisterRequest{}
|
||||||
err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey)
|
err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey2019)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -57,7 +57,7 @@ func (h *Headscale) PollNetMapHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mapRequest := tailcfg.MapRequest{}
|
mapRequest := tailcfg.MapRequest{}
|
||||||
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey)
|
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey2019)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
|
Loading…
Reference in a new issue