mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
introduce mapper package
The mapper package contains functions related to creating and marshalling reponses to machines. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
699655a93f
commit
f7f472ae07
18 changed files with 780 additions and 1063 deletions
47
flake.nix
47
flake.nix
|
@ -6,39 +6,36 @@
|
|||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs =
|
||||
{ self
|
||||
, nixpkgs
|
||||
, flake-utils
|
||||
, ...
|
||||
}:
|
||||
let
|
||||
outputs = {
|
||||
self,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
...
|
||||
}: let
|
||||
headscaleVersion =
|
||||
if (self ? shortRev)
|
||||
then self.shortRev
|
||||
else "dev";
|
||||
in
|
||||
{
|
||||
overlay = _: prev:
|
||||
let
|
||||
overlay = _: prev: let
|
||||
pkgs = nixpkgs.legacyPackages.${prev.system};
|
||||
in
|
||||
rec {
|
||||
in rec {
|
||||
headscale = pkgs.buildGo120Module rec {
|
||||
pname = "headscale";
|
||||
version = headscaleVersion;
|
||||
src = pkgs.lib.cleanSource self;
|
||||
|
||||
tags = [ "ts2019" ];
|
||||
tags = ["ts2019"];
|
||||
|
||||
# Only run unit tests when testing a build
|
||||
checkFlags = [ "-short" ];
|
||||
checkFlags = ["-short"];
|
||||
|
||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||
# update this if you have a mismatch after doing a change to thos files.
|
||||
vendorSha256 = "sha256-IOkbbFtE6+tNKnglE/8ZuNxhPSnloqM2sLgTvagMmnc=";
|
||||
vendorSha256 = "sha256-ui0V7a8bAAm5B7zfN9g2pWTyMpudtm10RgYWQwC6kcA=";
|
||||
|
||||
ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ];
|
||||
ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
|
||||
};
|
||||
|
||||
golines = pkgs.buildGoModule rec {
|
||||
|
@ -54,7 +51,7 @@
|
|||
|
||||
vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18=";
|
||||
|
||||
nativeBuildInputs = [ pkgs.installShellFiles ];
|
||||
nativeBuildInputs = [pkgs.installShellFiles];
|
||||
};
|
||||
|
||||
golangci-lint = prev.golangci-lint.override {
|
||||
|
@ -76,20 +73,19 @@
|
|||
|
||||
vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA=";
|
||||
|
||||
nativeBuildInputs = [ pkgs.installShellFiles ];
|
||||
nativeBuildInputs = [pkgs.installShellFiles];
|
||||
|
||||
subPackages = [ "protoc-gen-grpc-gateway" "protoc-gen-openapiv2" ];
|
||||
subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"];
|
||||
};
|
||||
};
|
||||
}
|
||||
// flake-utils.lib.eachDefaultSystem
|
||||
(system:
|
||||
let
|
||||
(system: let
|
||||
pkgs = import nixpkgs {
|
||||
overlays = [ self.overlay ];
|
||||
overlays = [self.overlay];
|
||||
inherit system;
|
||||
};
|
||||
buildDeps = with pkgs; [ git go_1_20 gnumake ];
|
||||
buildDeps = with pkgs; [git go_1_20 gnumake];
|
||||
devDeps = with pkgs;
|
||||
buildDeps
|
||||
++ [
|
||||
|
@ -123,11 +119,10 @@
|
|||
headscale-docker = pkgs.dockerTools.buildLayeredImage {
|
||||
name = "headscale";
|
||||
tag = headscaleVersion;
|
||||
contents = [ pkgs.headscale ];
|
||||
config.Entrypoint = [ (pkgs.headscale + "/bin/headscale") ];
|
||||
contents = [pkgs.headscale];
|
||||
config.Entrypoint = [(pkgs.headscale + "/bin/headscale")];
|
||||
};
|
||||
in
|
||||
rec {
|
||||
in rec {
|
||||
# `nix develop`
|
||||
devShell = pkgs.mkShell {
|
||||
buildInputs = devDeps;
|
||||
|
|
|
@ -1,115 +0,0 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (h *Headscale) generateMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
log.Trace().
|
||||
Str("func", "generateMapResponse").
|
||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "generateMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot convert to node")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := h.db.GetValidPeers(h.aclRules, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "generateMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
|
||||
|
||||
nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "generateMapResponse").
|
||||
Err(err).
|
||||
Msg("Failed to convert peers to Tailscale nodes")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.BaseDomain,
|
||||
*machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
Node: node,
|
||||
|
||||
// TODO: Only send if updated
|
||||
DERPMap: h.DERPMap,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Peers: nodePeers,
|
||||
|
||||
// TODO(kradalby): Implement:
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
||||
// PeersChanged
|
||||
// PeersRemoved
|
||||
// PeersChangedPatch
|
||||
// PeerSeenChange
|
||||
// OnlineChange
|
||||
|
||||
// TODO: Only send if updated
|
||||
DNSConfig: dnsConfig,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Domain: h.cfg.BaseDomain,
|
||||
|
||||
// Do not instruct clients to collect services, we do not
|
||||
// support or do anything with them
|
||||
CollectServices: "false",
|
||||
|
||||
// TODO: Only send if updated
|
||||
PacketFilter: h.aclRules,
|
||||
|
||||
UserProfiles: profiles,
|
||||
|
||||
// TODO: Only send if updated
|
||||
SSHPolicy: h.sshPolicy,
|
||||
|
||||
ControlTime: &now,
|
||||
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: !h.cfg.LogTail.Enabled,
|
||||
RandomizeClientPort: h.cfg.RandomizeClientPort,
|
||||
},
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("func", "generateMapResponse").
|
||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||
// Interface("payload", resp).
|
||||
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
|
||||
|
||||
return &resp, nil
|
||||
}
|
|
@ -80,7 +80,7 @@ type Headscale struct {
|
|||
dbString string
|
||||
dbType string
|
||||
dbDebug bool
|
||||
privateKey *key.MachinePrivate
|
||||
privateKey2019 *key.MachinePrivate
|
||||
noisePrivateKey *key.MachinePrivate
|
||||
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
@ -166,7 +166,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
|||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
dbString: dbString,
|
||||
privateKey: privateKey,
|
||||
privateKey2019: privateKey,
|
||||
noisePrivateKey: noisePrivateKey,
|
||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||
registrationCache: registrationCache,
|
||||
|
|
|
@ -39,15 +39,7 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
||||
// related to the application outside of tests.
|
||||
func (hsdb *HSDatabase) filterMachinesByACL(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
currentMachine *types.Machine, peers types.Machines,
|
||||
) types.Machines {
|
||||
return policy.FilterMachinesByACL(currentMachine, peers, aclRules)
|
||||
}
|
||||
|
||||
// ListPeers returns all peers of machine, regardless of any Policy.
|
||||
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
|
@ -72,67 +64,6 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error
|
|||
return machines, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPeers(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *types.Machine,
|
||||
) (types.Machines, error) {
|
||||
var peers types.Machines
|
||||
var err error
|
||||
|
||||
// If ACLs rules are defined, filter visible host list with the ACLs
|
||||
// else use the classic user scope
|
||||
if len(aclRules) > 0 {
|
||||
var machines []types.Machine
|
||||
machines, err = hsdb.ListMachines()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error retrieving list of machines")
|
||||
|
||||
return types.Machines{}, err
|
||||
}
|
||||
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
|
||||
} else {
|
||||
peers, err = hsdb.ListPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return types.Machines{}, err
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("self", machine.Hostname).
|
||||
Str("peers", peers.String()).
|
||||
Msg("Peers returned to caller")
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetValidPeers(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *types.Machine,
|
||||
) (types.Machines, error) {
|
||||
validPeers := make(types.Machines, 0)
|
||||
|
||||
peers, err := hsdb.getPeers(aclRules, machine)
|
||||
if err != nil {
|
||||
return types.Machines{}, err
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
if !peer.IsExpired() {
|
||||
validPeers = append(validPeers, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return validPeers, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) {
|
||||
machines := []types.Machine{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||
|
|
|
@ -293,8 +293,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
|
||||
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
|
||||
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules)
|
||||
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules)
|
||||
|
||||
c.Log(peersOfTestMachine)
|
||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||
|
|
|
@ -2,13 +2,11 @@ package db
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -163,32 +161,3 @@ func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string)
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[string]types.User)
|
||||
userMap[machine.User.Name] = machine.User
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||
}
|
||||
|
||||
profiles := []tailcfg.UserProfile{}
|
||||
for _, user := range userMap {
|
||||
displayName := user.Name
|
||||
|
||||
if hsdb.baseDomain != "" {
|
||||
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
|
||||
}
|
||||
|
||||
profiles = append(profiles,
|
||||
tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(user.ID),
|
||||
LoginName: user.Name,
|
||||
DisplayName: displayName,
|
||||
})
|
||||
}
|
||||
|
||||
return profiles
|
||||
}
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
|
@ -94,151 +91,6 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
|||
c.Assert(err, check.Equals, ErrUserExists)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
userShared1, err := db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared2, err := db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared3, err := db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKey2Shared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||
}
|
||||
db.db.Save(machineInShared1)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||
}
|
||||
db.db.Save(machineInShared2)
|
||||
|
||||
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||
}
|
||||
db.db.Save(machineInShared3)
|
||||
|
||||
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||
}
|
||||
db.db.Save(machine2InShared1)
|
||||
|
||||
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userProfiles := db.GetMapResponseUserProfiles(
|
||||
*machineInShared1,
|
||||
peersOfMachine1InShared1,
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
found := false
|
||||
for _, userProfiles := range userProfiles {
|
||||
if userProfiles.DisplayName == userShared1.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
|
||||
found = false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == userShared2.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser("old")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
|
|
@ -32,7 +32,7 @@ type DERPServer struct {
|
|||
|
||||
func (h *Headscale) NewDERPServer() (*DERPServer, error) {
|
||||
log.Trace().Caller().Msg("Creating new embedded DERP server")
|
||||
server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf)
|
||||
server := derp.NewServer(key.NodePrivate(*h.privateKey2019), log.Info().Msgf)
|
||||
region, err := h.generateRegionLocalDERP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -156,7 +156,7 @@ func (h *Headscale) DERPHandler(
|
|||
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
|
||||
|
||||
if !fastStart {
|
||||
pubKey := h.privateKey.Public()
|
||||
pubKey := h.privateKey2019.Public()
|
||||
pubKeyStr, _ := pubKey.MarshalText() //nolint
|
||||
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
|
||||
"Upgrade: DERP\r\n"+
|
||||
|
|
|
@ -3,14 +3,9 @@ package hscontrol
|
|||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
|
@ -23,10 +18,6 @@ const (
|
|||
ipv6AddressLength = 128
|
||||
)
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
)
|
||||
|
||||
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
||||
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
||||
// server (listening in 100.100.100.100 udp/53) should be used for.
|
||||
|
@ -158,63 +149,3 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
|||
|
||||
return fqdns
|
||||
}
|
||||
|
||||
// If any nextdns DoH resolvers are present in the list of resolvers it will
|
||||
// take metadata from the machine metadata and instruct tailscale to add it
|
||||
// to the requests. This makes it possible to identify from which device the
|
||||
// requests come in the NextDNS dashboard.
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{machine.Hostname},
|
||||
"device_model": []string{machine.HostInfo.OS},
|
||||
}
|
||||
|
||||
if len(machine.IPAddresses) > 0 {
|
||||
attrs.Add("device_ip", machine.IPAddresses[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getMapResponseDNSConfig(
|
||||
dnsConfigOrig *tailcfg.DNSConfig,
|
||||
baseDomain string,
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) *tailcfg.DNSConfig {
|
||||
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
|
||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
||||
// Only inject the Search Domain of the current user - shared nodes should use their full FQDN
|
||||
dnsConfig.Domains = append(
|
||||
dnsConfig.Domains,
|
||||
fmt.Sprintf(
|
||||
"%s.%s",
|
||||
machine.User.Name,
|
||||
baseDomain,
|
||||
),
|
||||
)
|
||||
|
||||
userSet := mapset.NewSet[types.User]()
|
||||
userSet.Add(machine.User)
|
||||
for _, p := range peers {
|
||||
userSet.Add(p.User)
|
||||
}
|
||||
for _, user := range userSet.ToSlice() {
|
||||
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
||||
dnsConfig.Routes[dnsRoute] = nil
|
||||
}
|
||||
} else {
|
||||
dnsConfig = dnsConfigOrig
|
||||
}
|
||||
|
||||
addNextDNSMetadata(dnsConfig.Resolvers, machine)
|
||||
|
||||
return dnsConfig
|
||||
}
|
||||
|
|
|
@ -1,14 +1,9 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
)
|
||||
|
||||
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||
|
@ -112,293 +107,3 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
|
|||
c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
||||
c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
userShared1, err := app.db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := app.db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := app.db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
PreAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared3)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machine2InShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
dnsConfigOrig := tailcfg.DNSConfig{
|
||||
Routes: make(map[string][]*dnstype.Resolver),
|
||||
Domains: []string{baseDomain},
|
||||
Proxied: true,
|
||||
}
|
||||
|
||||
peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
&dnsConfigOrig,
|
||||
baseDomain,
|
||||
*machineInShared1,
|
||||
peersOfMachineInShared1,
|
||||
)
|
||||
c.Assert(dnsConfig, check.NotNil)
|
||||
|
||||
c.Assert(len(dnsConfig.Routes), check.Equals, 3)
|
||||
|
||||
domainRouteShared1 := fmt.Sprintf("%s.%s", userShared1.Name, baseDomain)
|
||||
_, ok := dnsConfig.Routes[domainRouteShared1]
|
||||
c.Assert(ok, check.Equals, true)
|
||||
|
||||
domainRouteShared2 := fmt.Sprintf("%s.%s", userShared2.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[domainRouteShared2]
|
||||
c.Assert(ok, check.Equals, true)
|
||||
|
||||
domainRouteShared3 := fmt.Sprintf("%s.%s", userShared3.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[domainRouteShared3]
|
||||
c.Assert(ok, check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
userShared1, err := app.db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := app.db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := app.db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machineInShared3)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||
}
|
||||
err = app.db.MachineSave(machine2InShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
dnsConfigOrig := tailcfg.DNSConfig{
|
||||
Routes: make(map[string][]*dnstype.Resolver),
|
||||
Domains: []string{baseDomain},
|
||||
Proxied: false,
|
||||
}
|
||||
|
||||
peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
&dnsConfigOrig,
|
||||
baseDomain,
|
||||
*machineInShared1,
|
||||
peersOfMachine1Shared1,
|
||||
)
|
||||
c.Assert(dnsConfig, check.NotNil)
|
||||
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
|
||||
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
|
||||
}
|
||||
|
|
418
hscontrol/mapper/mapper.go
Normal file
418
hscontrol/mapper/mapper.go
Normal file
|
@ -0,0 +1,418 @@
|
|||
package mapper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
type Mapper struct {
|
||||
db *db.HSDatabase
|
||||
|
||||
privateKey2019 *key.MachinePrivate
|
||||
isNoise bool
|
||||
|
||||
// Configuration
|
||||
// TODO(kradalby): figure out if this is the format we want this in
|
||||
derpMap *tailcfg.DERPMap
|
||||
baseDomain string
|
||||
dnsCfg *tailcfg.DNSConfig
|
||||
logtail bool
|
||||
randomClientPort bool
|
||||
stripEmailDomain bool
|
||||
}
|
||||
|
||||
func NewMapper(
|
||||
db *db.HSDatabase,
|
||||
privateKey *key.MachinePrivate,
|
||||
isNoise bool,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
baseDomain string,
|
||||
dnsCfg *tailcfg.DNSConfig,
|
||||
logtail bool,
|
||||
randomClientPort bool,
|
||||
stripEmailDomain bool,
|
||||
) *Mapper {
|
||||
return &Mapper{
|
||||
db: db,
|
||||
|
||||
privateKey2019: privateKey,
|
||||
isNoise: isNoise,
|
||||
|
||||
derpMap: derpMap,
|
||||
baseDomain: baseDomain,
|
||||
dnsCfg: dnsCfg,
|
||||
logtail: logtail,
|
||||
randomClientPort: randomClientPort,
|
||||
stripEmailDomain: stripEmailDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m Mapper) fullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
|
||||
// TODO(kradalby): Decouple this from DB?
|
||||
node, err := m.db.TailNode(*machine, pol, m.dnsCfg)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot convert to node")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := m.db.ListPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rules, sshPolicy, err := policy.GenerateFilterRules(pol, peers, m.stripEmailDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(rules) > 0 {
|
||||
peers = policy.FilterMachinesByACL(machine, peers, rules)
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(machine, peers, m.baseDomain)
|
||||
|
||||
// TODO(kradalby): Decouple this from DB?
|
||||
nodePeers, err := m.db.TailNodes(peers, pol, m.dnsCfg)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to convert peers to Tailscale nodes")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Shold this mutation happen before TailNode(s) is called?
|
||||
dnsConfig := generateDNSConfig(
|
||||
m.dnsCfg,
|
||||
m.baseDomain,
|
||||
*machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
Node: node,
|
||||
|
||||
// TODO: Only send if updated
|
||||
DERPMap: m.derpMap,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Peers: nodePeers,
|
||||
|
||||
// TODO(kradalby): Implement:
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
||||
// PeersChanged
|
||||
// PeersRemoved
|
||||
// PeersChangedPatch
|
||||
// PeerSeenChange
|
||||
// OnlineChange
|
||||
|
||||
// TODO: Only send if updated
|
||||
DNSConfig: dnsConfig,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Domain: m.baseDomain,
|
||||
|
||||
// Do not instruct clients to collect services, we do not
|
||||
// support or do anything with them
|
||||
CollectServices: "false",
|
||||
|
||||
// TODO: Only send if updated
|
||||
PacketFilter: rules,
|
||||
|
||||
UserProfiles: profiles,
|
||||
|
||||
// TODO: Only send if updated
|
||||
SSHPolicy: sshPolicy,
|
||||
|
||||
ControlTime: &now,
|
||||
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: !m.logtail,
|
||||
RandomizeClientPort: m.randomClientPort,
|
||||
},
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||
// Interface("payload", resp).
|
||||
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
baseDomain string,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[string]types.User)
|
||||
userMap[machine.User.Name] = machine.User
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||
}
|
||||
|
||||
profiles := []tailcfg.UserProfile{}
|
||||
for _, user := range userMap {
|
||||
displayName := user.Name
|
||||
|
||||
if baseDomain != "" {
|
||||
displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain)
|
||||
}
|
||||
|
||||
profiles = append(profiles,
|
||||
tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(user.ID),
|
||||
LoginName: user.Name,
|
||||
DisplayName: displayName,
|
||||
})
|
||||
}
|
||||
|
||||
return profiles
|
||||
}
|
||||
|
||||
func generateDNSConfig(
|
||||
base *tailcfg.DNSConfig,
|
||||
baseDomain string,
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) *tailcfg.DNSConfig {
|
||||
dnsConfig := base.Clone()
|
||||
|
||||
// if MagicDNS is enabled
|
||||
if base != nil && base.Proxied {
|
||||
// Only inject the Search Domain of the current user
|
||||
// shared nodes should use their full FQDN
|
||||
dnsConfig.Domains = append(
|
||||
dnsConfig.Domains,
|
||||
fmt.Sprintf(
|
||||
"%s.%s",
|
||||
machine.User.Name,
|
||||
baseDomain,
|
||||
),
|
||||
)
|
||||
|
||||
userSet := mapset.NewSet[types.User]()
|
||||
userSet.Add(machine.User)
|
||||
for _, p := range peers {
|
||||
userSet.Add(p.User)
|
||||
}
|
||||
for _, user := range userSet.ToSlice() {
|
||||
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
||||
dnsConfig.Routes[dnsRoute] = nil
|
||||
}
|
||||
} else {
|
||||
dnsConfig = base
|
||||
}
|
||||
|
||||
addNextDNSMetadata(dnsConfig.Resolvers, machine)
|
||||
|
||||
return dnsConfig
|
||||
}
|
||||
|
||||
// If any nextdns DoH resolvers are present in the list of resolvers it will
|
||||
// take metadata from the machine metadata and instruct tailscale to add it
|
||||
// to the requests. This makes it possible to identify from which device the
|
||||
// requests come in the NextDNS dashboard.
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{machine.Hostname},
|
||||
"device_model": []string{machine.HostInfo.OS},
|
||||
}
|
||||
|
||||
if len(machine.IPAddresses) > 0 {
|
||||
attrs.Add("device_ip", machine.IPAddresses[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m Mapper) CreateMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
) ([]byte, error) {
|
||||
mapResponse, err := m.fullMapResponse(mapRequest, machine, pol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) CreateKeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
) ([]byte, error) {
|
||||
keepAliveResponse := tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(
|
||||
keepAliveResponse,
|
||||
key.MachinePublic{},
|
||||
mapRequest.Compress,
|
||||
)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func MarshalResponse(
|
||||
resp interface{},
|
||||
privateKey2019 *key.MachinePrivate,
|
||||
machineKey key.MachinePublic,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal response")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if privateKey2019 != nil {
|
||||
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
func (m Mapper) marshalMapResponse(
|
||||
resp interface{},
|
||||
machineKey key.MachinePublic,
|
||||
compression string,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal map response")
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
if !m.isNoise { // if legacy protocol
|
||||
respBody = m.privateKey2019.SealTo(machineKey, respBody)
|
||||
}
|
||||
} else {
|
||||
if !m.isNoise { // if legacy protocol
|
||||
respBody = m.privateKey2019.SealTo(machineKey, jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
131
hscontrol/mapper/mapper_test.go
Normal file
131
hscontrol/mapper/mapper_test.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package mapper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
mach := func(hostname, username string, userid uint) types.Machine {
|
||||
return types.Machine{
|
||||
Hostname: hostname,
|
||||
UserID: userid,
|
||||
User: types.User{
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
machineInShared1 := mach("test_get_shared_nodes_1", "user1", 1)
|
||||
machineInShared2 := mach("test_get_shared_nodes_2", "user2", 2)
|
||||
machineInShared3 := mach("test_get_shared_nodes_3", "user3", 3)
|
||||
machine2InShared1 := mach("test_get_shared_nodes_4", "user1", 1)
|
||||
|
||||
userProfiles := generateUserProfiles(
|
||||
&machineInShared1,
|
||||
types.Machines{
|
||||
machineInShared2, machineInShared3, machine2InShared1,
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
users := []string{
|
||||
"user1", "user2", "user3",
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
found := false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == user {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSConfigMapResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
magicDNS bool
|
||||
want *tailcfg.DNSConfig
|
||||
}{
|
||||
{
|
||||
magicDNS: true,
|
||||
want: &tailcfg.DNSConfig{
|
||||
Routes: map[string][]*dnstype.Resolver{
|
||||
"shared1.foobar.headscale.net": {},
|
||||
"shared2.foobar.headscale.net": {},
|
||||
"shared3.foobar.headscale.net": {},
|
||||
},
|
||||
Domains: []string{
|
||||
"foobar.headscale.net",
|
||||
"shared1.foobar.headscale.net",
|
||||
},
|
||||
Proxied: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
magicDNS: false,
|
||||
want: &tailcfg.DNSConfig{
|
||||
Domains: []string{"foobar.headscale.net"},
|
||||
Proxied: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) {
|
||||
mach := func(hostname, username string, userid uint) types.Machine {
|
||||
return types.Machine{
|
||||
Hostname: hostname,
|
||||
UserID: userid,
|
||||
User: types.User{
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
|
||||
dnsConfigOrig := tailcfg.DNSConfig{
|
||||
Routes: make(map[string][]*dnstype.Resolver),
|
||||
Domains: []string{baseDomain},
|
||||
Proxied: tt.magicDNS,
|
||||
}
|
||||
|
||||
machineInShared1 := mach("test_get_shared_nodes_1", "shared1", 1)
|
||||
machineInShared2 := mach("test_get_shared_nodes_2", "shared2", 2)
|
||||
machineInShared3 := mach("test_get_shared_nodes_3", "shared3", 3)
|
||||
machine2InShared1 := mach("test_get_shared_nodes_4", "shared1", 1)
|
||||
|
||||
peersOfMachineInShared1 := types.Machines{
|
||||
machineInShared1,
|
||||
machineInShared2,
|
||||
machineInShared3,
|
||||
machine2InShared1,
|
||||
}
|
||||
|
||||
got := generateDNSConfig(
|
||||
&dnsConfigOrig,
|
||||
baseDomain,
|
||||
machineInShared1,
|
||||
peersOfMachineInShared1,
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
15
hscontrol/mapper/suite_test.go
Normal file
15
hscontrol/mapper/suite_test.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package mapper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
check.TestingT(t)
|
||||
}
|
||||
|
||||
var _ = check.Suite(&Suite{})
|
||||
|
||||
type Suite struct{}
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -61,7 +62,7 @@ func (h *Headscale) KeyHandler(
|
|||
// TS2021 (Tailscale v2 protocol) requires to have a different key
|
||||
if clientCapabilityVersion >= NoiseCapabilityVersion {
|
||||
resp := tailcfg.OverTLSPublicKeyResponse{
|
||||
LegacyPublicKey: h.privateKey.Public(),
|
||||
LegacyPublicKey: h.privateKey2019.Public(),
|
||||
PublicKey: h.noisePrivateKey.Public(),
|
||||
}
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
|
@ -84,7 +85,7 @@ func (h *Headscale) KeyHandler(
|
|||
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
||||
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -323,7 +324,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
Msg("Failed authentication via AuthKey")
|
||||
resp.MachineAuthorized = false
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -483,7 +484,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
||||
resp.Login = *pak.User.TailscaleLogin()
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -548,7 +549,7 @@ func (h *Headscale) handleNewMachineCommon(
|
|||
registerRequest.NodeKey)
|
||||
}
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -609,7 +610,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
|||
resp.MachineAuthorized = false
|
||||
resp.NodeKeyExpired = true
|
||||
resp.User = *machine.User.TailscaleUser()
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -673,7 +674,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
|||
resp.User = *machine.User.TailscaleUser()
|
||||
resp.Login = *machine.User.TailscaleLogin()
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -735,7 +736,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
|||
|
||||
resp.AuthURL = ""
|
||||
resp.User = *machine.User.TailscaleUser()
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -802,7 +803,7 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
|
|||
registerRequest.NodeKey)
|
||||
}
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -29,6 +30,19 @@ func (h *Headscale) handlePollCommon(
|
|||
mapRequest tailcfg.MapRequest,
|
||||
isNoise bool,
|
||||
) {
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(h.db,
|
||||
h.privateKey2019,
|
||||
isNoise,
|
||||
h.DERPMap,
|
||||
h.cfg.BaseDomain,
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.LogTail.Enabled,
|
||||
h.cfg.RandomizeClientPort,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
)
|
||||
|
||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
||||
|
@ -87,7 +101,7 @@ func (h *Headscale) handlePollCommon(
|
|||
return
|
||||
}
|
||||
|
||||
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
||||
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
|
@ -245,6 +259,19 @@ func (h *Headscale) pollNetMapStream(
|
|||
updateChan chan struct{},
|
||||
isNoise bool,
|
||||
) {
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(h.db,
|
||||
h.privateKey2019,
|
||||
isNoise,
|
||||
h.DERPMap,
|
||||
h.cfg.BaseDomain,
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.LogTail.Enabled,
|
||||
h.cfg.RandomizeClientPort,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
)
|
||||
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
||||
|
@ -463,7 +490,7 @@ func (h *Headscale) pollNetMapStream(
|
|||
Time("last_successful_update", lastUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
||||
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
|
||||
data, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
|
@ -623,6 +650,19 @@ func (h *Headscale) scheduledPollWorker(
|
|||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) {
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(h.db,
|
||||
h.privateKey2019,
|
||||
isNoise,
|
||||
h.DERPMap,
|
||||
h.cfg.BaseDomain,
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.LogTail.Enabled,
|
||||
h.cfg.RandomizeClientPort,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
)
|
||||
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
|
||||
|
||||
|
@ -643,7 +683,7 @@ func (h *Headscale) scheduledPollWorker(
|
|||
return
|
||||
|
||||
case <-keepAliveTicker.C:
|
||||
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise)
|
||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "keepAlive").
|
||||
|
|
|
@ -1,157 +1 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func (h *Headscale) getMapResponseData(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
mapResponse, err := h.generateMapResponse(mapRequest, machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNoise {
|
||||
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise)
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapKeepAliveResponseData(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
keepAliveResponse := tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
if isNoise {
|
||||
return h.marshalMapResponse(
|
||||
keepAliveResponse,
|
||||
key.MachinePublic{},
|
||||
mapRequest.Compress,
|
||||
isNoise,
|
||||
)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise)
|
||||
}
|
||||
|
||||
func (h *Headscale) marshalResponse(
|
||||
resp interface{},
|
||||
machineKey key.MachinePublic,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal response")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNoise {
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
return h.privateKey.SealTo(machineKey, jsonBody), nil
|
||||
}
|
||||
|
||||
func (h *Headscale) marshalMapResponse(
|
||||
resp interface{},
|
||||
machineKey key.MachinePublic,
|
||||
compression string,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal map response")
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
if !isNoise { // if legacy protocol
|
||||
respBody = h.privateKey.SealTo(machineKey, respBody)
|
||||
}
|
||||
} else {
|
||||
if !isNoise { // if legacy protocol
|
||||
respBody = h.privateKey.SealTo(machineKey, jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func (h *Headscale) RegistrationHandler(
|
|||
return
|
||||
}
|
||||
registerRequest := tailcfg.RegisterRequest{}
|
||||
err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey)
|
||||
err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey2019)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
|
|
@ -57,7 +57,7 @@ func (h *Headscale) PollNetMapHandler(
|
|||
return
|
||||
}
|
||||
mapRequest := tailcfg.MapRequest{}
|
||||
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey)
|
||||
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey2019)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
|
|
Loading…
Reference in a new issue