introduce mapper package

The mapper package contains functions related to creating and marshalling
reponses to machines.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-26 11:26:34 +01:00 committed by Kristoffer Dalby
parent 699655a93f
commit f7f472ae07
18 changed files with 780 additions and 1063 deletions

View file

@ -6,39 +6,36 @@
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
}; };
outputs = outputs = {
{ self self,
, nixpkgs nixpkgs,
, flake-utils flake-utils,
, ... ...
}: }: let
let
headscaleVersion = headscaleVersion =
if (self ? shortRev) if (self ? shortRev)
then self.shortRev then self.shortRev
else "dev"; else "dev";
in in
{ {
overlay = _: prev: overlay = _: prev: let
let
pkgs = nixpkgs.legacyPackages.${prev.system}; pkgs = nixpkgs.legacyPackages.${prev.system};
in in rec {
rec {
headscale = pkgs.buildGo120Module rec { headscale = pkgs.buildGo120Module rec {
pname = "headscale"; pname = "headscale";
version = headscaleVersion; version = headscaleVersion;
src = pkgs.lib.cleanSource self; src = pkgs.lib.cleanSource self;
tags = [ "ts2019" ]; tags = ["ts2019"];
# Only run unit tests when testing a build # Only run unit tests when testing a build
checkFlags = [ "-short" ]; checkFlags = ["-short"];
# When updating go.mod or go.sum, a new sha will need to be calculated, # When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files. # update this if you have a mismatch after doing a change to thos files.
vendorSha256 = "sha256-IOkbbFtE6+tNKnglE/8ZuNxhPSnloqM2sLgTvagMmnc="; vendorSha256 = "sha256-ui0V7a8bAAm5B7zfN9g2pWTyMpudtm10RgYWQwC6kcA=";
ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ]; ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
}; };
golines = pkgs.buildGoModule rec { golines = pkgs.buildGoModule rec {
@ -54,7 +51,7 @@
vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18="; vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18=";
nativeBuildInputs = [ pkgs.installShellFiles ]; nativeBuildInputs = [pkgs.installShellFiles];
}; };
golangci-lint = prev.golangci-lint.override { golangci-lint = prev.golangci-lint.override {
@ -76,20 +73,19 @@
vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA="; vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA=";
nativeBuildInputs = [ pkgs.installShellFiles ]; nativeBuildInputs = [pkgs.installShellFiles];
subPackages = [ "protoc-gen-grpc-gateway" "protoc-gen-openapiv2" ]; subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"];
}; };
}; };
} }
// flake-utils.lib.eachDefaultSystem // flake-utils.lib.eachDefaultSystem
(system: (system: let
let
pkgs = import nixpkgs { pkgs = import nixpkgs {
overlays = [ self.overlay ]; overlays = [self.overlay];
inherit system; inherit system;
}; };
buildDeps = with pkgs; [ git go_1_20 gnumake ]; buildDeps = with pkgs; [git go_1_20 gnumake];
devDeps = with pkgs; devDeps = with pkgs;
buildDeps buildDeps
++ [ ++ [
@ -123,11 +119,10 @@
headscale-docker = pkgs.dockerTools.buildLayeredImage { headscale-docker = pkgs.dockerTools.buildLayeredImage {
name = "headscale"; name = "headscale";
tag = headscaleVersion; tag = headscaleVersion;
contents = [ pkgs.headscale ]; contents = [pkgs.headscale];
config.Entrypoint = [ (pkgs.headscale + "/bin/headscale") ]; config.Entrypoint = [(pkgs.headscale + "/bin/headscale")];
}; };
in in rec {
rec {
# `nix develop` # `nix develop`
devShell = pkgs.mkShell { devShell = pkgs.mkShell {
buildInputs = devDeps; buildInputs = devDeps;

View file

@ -1,115 +0,0 @@
package hscontrol
import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
func (h *Headscale) generateMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
) (*tailcfg.MapResponse, error) {
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.db.GetValidPeers(h.aclRules, machine)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
// TODO: Only send if updated
DERPMap: h.DERPMap,
// TODO: Only send if updated
Peers: nodePeers,
// TODO(kradalby): Implement:
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
// PeersChanged
// PeersRemoved
// PeersChangedPatch
// PeerSeenChange
// OnlineChange
// TODO: Only send if updated
DNSConfig: dnsConfig,
// TODO: Only send if updated
Domain: h.cfg.BaseDomain,
// Do not instruct clients to collect services, we do not
// support or do anything with them
CollectServices: "false",
// TODO: Only send if updated
PacketFilter: h.aclRules,
UserProfiles: profiles,
// TODO: Only send if updated
SSHPolicy: h.sshPolicy,
ControlTime: &now,
Debug: &tailcfg.Debug{
DisableLogTail: !h.cfg.LogTail.Enabled,
RandomizeClientPort: h.cfg.RandomizeClientPort,
},
}
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp).
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
return &resp, nil
}

View file

@ -80,7 +80,7 @@ type Headscale struct {
dbString string dbString string
dbType string dbType string
dbDebug bool dbDebug bool
privateKey *key.MachinePrivate privateKey2019 *key.MachinePrivate
noisePrivateKey *key.MachinePrivate noisePrivateKey *key.MachinePrivate
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
@ -166,7 +166,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
privateKey: privateKey, privateKey2019: privateKey,
noisePrivateKey: noisePrivateKey, noisePrivateKey: noisePrivateKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache, registrationCache: registrationCache,

View file

@ -39,15 +39,7 @@ var (
) )
) )
// filterMachinesByACL wrapper function to not have devs pass around locks and maps // ListPeers returns all peers of machine, regardless of any Policy.
// related to the application outside of tests.
func (hsdb *HSDatabase) filterMachinesByACL(
aclRules []tailcfg.FilterRule,
currentMachine *types.Machine, peers types.Machines,
) types.Machines {
return policy.FilterMachinesByACL(currentMachine, peers, aclRules)
}
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
@ -72,67 +64,6 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error
return machines, nil return machines, nil
} }
func (hsdb *HSDatabase) getPeers(
aclRules []tailcfg.FilterRule,
machine *types.Machine,
) (types.Machines, error) {
var peers types.Machines
var err error
// If ACLs rules are defined, filter visible host list with the ACLs
// else use the classic user scope
if len(aclRules) > 0 {
var machines []types.Machine
machines, err = hsdb.ListMachines()
if err != nil {
log.Error().Err(err).Msg("Error retrieving list of machines")
return types.Machines{}, err
}
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
} else {
peers, err = hsdb.ListPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
return types.Machines{}, err
}
}
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Caller().
Str("self", machine.Hostname).
Str("peers", peers.String()).
Msg("Peers returned to caller")
return peers, nil
}
func (hsdb *HSDatabase) GetValidPeers(
aclRules []tailcfg.FilterRule,
machine *types.Machine,
) (types.Machines, error) {
validPeers := make(types.Machines, 0)
peers, err := hsdb.getPeers(aclRules, machine)
if err != nil {
return types.Machines{}, err
}
for _, peer := range peers {
if !peer.IsExpired() {
validPeers = append(validPeers, peer)
}
}
return validPeers, nil
}
func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) {
machines := []types.Machine{} machines := []types.Machine{}
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {

View file

@ -293,8 +293,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules)
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules)
c.Log(peersOfTestMachine) c.Log(peersOfTestMachine)
c.Assert(len(peersOfTestMachine), check.Equals, 9) c.Assert(len(peersOfTestMachine), check.Equals, 9)

View file

@ -2,13 +2,11 @@ package db
import ( import (
"errors" "errors"
"fmt"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg"
) )
var ( var (
@ -163,32 +161,3 @@ func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string)
return nil return nil
} }
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
machine types.Machine,
peers types.Machines,
) []tailcfg.UserProfile {
userMap := make(map[string]types.User)
userMap[machine.User.Name] = machine.User
for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
for _, user := range userMap {
displayName := user.Name
if hsdb.baseDomain != "" {
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
}
profiles = append(profiles,
tailcfg.UserProfile{
ID: tailcfg.UserID(user.ID),
LoginName: user.Name,
DisplayName: displayName,
})
}
return profiles
}

View file

@ -1,13 +1,10 @@
package db package db
import ( import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg"
) )
func (s *Suite) TestCreateAndDestroyUser(c *check.C) { func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
@ -94,151 +91,6 @@ func (s *Suite) TestRenameUser(c *check.C) {
c.Assert(err, check.Equals, ErrUserExists) c.Assert(err, check.Equals, ErrUserExists)
} }
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
userShared1, err := db.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := db.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := db.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyShared1, err := db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared2, err := db.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared3, err := db.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2Shared1, err := db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &types.Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID),
}
db.db.Save(machineInShared1)
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &types.Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID),
}
db.db.Save(machineInShared2)
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &types.Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID),
}
db.db.Save(machineInShared3)
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &types.Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
db.db.Save(machine2InShared1)
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
c.Assert(err, check.IsNil)
userProfiles := db.GetMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
c.Assert(len(userProfiles), check.Equals, 3)
found := false
for _, userProfiles := range userProfiles {
if userProfiles.DisplayName == userShared1.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == userShared2.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func (s *Suite) TestSetMachineUser(c *check.C) { func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := db.CreateUser("old") oldUser, err := db.CreateUser("old")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -32,7 +32,7 @@ type DERPServer struct {
func (h *Headscale) NewDERPServer() (*DERPServer, error) { func (h *Headscale) NewDERPServer() (*DERPServer, error) {
log.Trace().Caller().Msg("Creating new embedded DERP server") log.Trace().Caller().Msg("Creating new embedded DERP server")
server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf) server := derp.NewServer(key.NodePrivate(*h.privateKey2019), log.Info().Msgf)
region, err := h.generateRegionLocalDERP() region, err := h.generateRegionLocalDERP()
if err != nil { if err != nil {
return nil, err return nil, err
@ -156,7 +156,7 @@ func (h *Headscale) DERPHandler(
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr) log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
if !fastStart { if !fastStart {
pubKey := h.privateKey.Public() pubKey := h.privateKey2019.Public()
pubKeyStr, _ := pubKey.MarshalText() //nolint pubKeyStr, _ := pubKey.MarshalText() //nolint
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+ fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
"Upgrade: DERP\r\n"+ "Upgrade: DERP\r\n"+

View file

@ -3,14 +3,9 @@ package hscontrol
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"strings" "strings"
mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
) )
@ -23,10 +18,6 @@ const (
ipv6AddressLength = 128 ipv6AddressLength = 128
) )
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
)
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS // This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for. // server (listening in 100.100.100.100 udp/53) should be used for.
@ -158,63 +149,3 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
return fqdns return fqdns
} }
// If any nextdns DoH resolvers are present in the list of resolvers it will
// take metadata from the machine metadata and instruct tailscale to add it
// to the requests. This makes it possible to identify from which device the
// requests come in the NextDNS dashboard.
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{machine.Hostname},
"device_model": []string{machine.HostInfo.OS},
}
if len(machine.IPAddresses) > 0 {
attrs.Add("device_ip", machine.IPAddresses[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
}
}
}
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
machine types.Machine,
peers types.Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current user - shared nodes should use their full FQDN
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf(
"%s.%s",
machine.User.Name,
baseDomain,
),
)
userSet := mapset.NewSet[types.User]()
userSet.Add(machine.User)
for _, p := range peers {
userSet.Add(p.User)
}
for _, user := range userSet.ToSlice() {
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
dnsConfig.Routes[dnsRoute] = nil
}
} else {
dnsConfig = dnsConfigOrig
}
addNextDNSMetadata(dnsConfig.Resolvers, machine)
return dnsConfig
}

View file

@ -1,14 +1,9 @@
package hscontrol package hscontrol
import ( import (
"fmt"
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
) )
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) { func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
@ -112,293 +107,3 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true) c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true) c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
} }
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
userShared1, err := app.db.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.db.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.db.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
PreAuthKey2InShared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &types.Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
err = app.db.MachineSave(machineInShared1)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &types.Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
err = app.db.MachineSave(machineInShared2)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &types.Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
err = app.db.MachineSave(machineInShared3)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &types.Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID),
}
err = app.db.MachineSave(machine2InShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: true,
}
peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachineInShared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 3)
domainRouteShared1 := fmt.Sprintf("%s.%s", userShared1.Name, baseDomain)
_, ok := dnsConfig.Routes[domainRouteShared1]
c.Assert(ok, check.Equals, true)
domainRouteShared2 := fmt.Sprintf("%s.%s", userShared2.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared2]
c.Assert(ok, check.Equals, true)
domainRouteShared3 := fmt.Sprintf("%s.%s", userShared3.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared3]
c.Assert(ok, check.Equals, true)
}
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
userShared1, err := app.db.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.db.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.db.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2InShared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &types.Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
err = app.db.MachineSave(machineInShared1)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &types.Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
err = app.db.MachineSave(machineInShared2)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &types.Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
err = app.db.MachineSave(machineInShared3)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &types.Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID),
}
err = app.db.MachineSave(machine2InShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: false,
}
peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachine1Shared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
}

418
hscontrol/mapper/mapper.go Normal file
View file

@ -0,0 +1,418 @@
package mapper
import (
"encoding/binary"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
)
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
reservedResponseHeaderSize = 4
)
type Mapper struct {
db *db.HSDatabase
privateKey2019 *key.MachinePrivate
isNoise bool
// Configuration
// TODO(kradalby): figure out if this is the format we want this in
derpMap *tailcfg.DERPMap
baseDomain string
dnsCfg *tailcfg.DNSConfig
logtail bool
randomClientPort bool
stripEmailDomain bool
}
func NewMapper(
db *db.HSDatabase,
privateKey *key.MachinePrivate,
isNoise bool,
derpMap *tailcfg.DERPMap,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
logtail bool,
randomClientPort bool,
stripEmailDomain bool,
) *Mapper {
return &Mapper{
db: db,
privateKey2019: privateKey,
isNoise: isNoise,
derpMap: derpMap,
baseDomain: baseDomain,
dnsCfg: dnsCfg,
logtail: logtail,
randomClientPort: randomClientPort,
stripEmailDomain: stripEmailDomain,
}
}
func (m Mapper) fullMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
pol *policy.ACLPolicy,
) (*tailcfg.MapResponse, error) {
log.Trace().
Caller().
Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response")
// TODO(kradalby): Decouple this from DB?
node, err := m.db.TailNode(*machine, pol, m.dnsCfg)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := m.db.ListPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
rules, sshPolicy, err := policy.GenerateFilterRules(pol, peers, m.stripEmailDomain)
if err != nil {
return nil, err
}
if len(rules) > 0 {
peers = policy.FilterMachinesByACL(machine, peers, rules)
}
profiles := generateUserProfiles(machine, peers, m.baseDomain)
// TODO(kradalby): Decouple this from DB?
nodePeers, err := m.db.TailNodes(peers, pol, m.dnsCfg)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
// TODO(kradalby): Shold this mutation happen before TailNode(s) is called?
dnsConfig := generateDNSConfig(
m.dnsCfg,
m.baseDomain,
*machine,
peers,
)
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
// TODO: Only send if updated
DERPMap: m.derpMap,
// TODO: Only send if updated
Peers: nodePeers,
// TODO(kradalby): Implement:
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
// PeersChanged
// PeersRemoved
// PeersChangedPatch
// PeerSeenChange
// OnlineChange
// TODO: Only send if updated
DNSConfig: dnsConfig,
// TODO: Only send if updated
Domain: m.baseDomain,
// Do not instruct clients to collect services, we do not
// support or do anything with them
CollectServices: "false",
// TODO: Only send if updated
PacketFilter: rules,
UserProfiles: profiles,
// TODO: Only send if updated
SSHPolicy: sshPolicy,
ControlTime: &now,
Debug: &tailcfg.Debug{
DisableLogTail: !m.logtail,
RandomizeClientPort: m.randomClientPort,
},
}
log.Trace().
Caller().
Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp).
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
return &resp, nil
}
func generateUserProfiles(
machine *types.Machine,
peers types.Machines,
baseDomain string,
) []tailcfg.UserProfile {
userMap := make(map[string]types.User)
userMap[machine.User.Name] = machine.User
for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
for _, user := range userMap {
displayName := user.Name
if baseDomain != "" {
displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain)
}
profiles = append(profiles,
tailcfg.UserProfile{
ID: tailcfg.UserID(user.ID),
LoginName: user.Name,
DisplayName: displayName,
})
}
return profiles
}
func generateDNSConfig(
base *tailcfg.DNSConfig,
baseDomain string,
machine types.Machine,
peers types.Machines,
) *tailcfg.DNSConfig {
dnsConfig := base.Clone()
// if MagicDNS is enabled
if base != nil && base.Proxied {
// Only inject the Search Domain of the current user
// shared nodes should use their full FQDN
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf(
"%s.%s",
machine.User.Name,
baseDomain,
),
)
userSet := mapset.NewSet[types.User]()
userSet.Add(machine.User)
for _, p := range peers {
userSet.Add(p.User)
}
for _, user := range userSet.ToSlice() {
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
dnsConfig.Routes[dnsRoute] = nil
}
} else {
dnsConfig = base
}
addNextDNSMetadata(dnsConfig.Resolvers, machine)
return dnsConfig
}
// If any nextdns DoH resolvers are present in the list of resolvers it will
// take metadata from the machine metadata and instruct tailscale to add it
// to the requests. This makes it possible to identify from which device the
// requests come in the NextDNS dashboard.
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{machine.Hostname},
"device_model": []string{machine.HostInfo.OS},
}
if len(machine.IPAddresses) > 0 {
attrs.Add("device_ip", machine.IPAddresses[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
}
}
}
func (m Mapper) CreateMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
pol *policy.ACLPolicy,
) ([]byte, error) {
mapResponse, err := m.fullMapResponse(mapRequest, machine, pol)
if err != nil {
return nil, err
}
if m.isNoise {
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
}
func (m Mapper) CreateKeepAliveResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{
KeepAlive: true,
}
if m.isNoise {
return m.marshalMapResponse(
keepAliveResponse,
key.MachinePublic{},
mapRequest.Compress,
)
}
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
}
func MarshalResponse(
resp interface{},
privateKey2019 *key.MachinePrivate,
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if privateKey2019 != nil {
return privateKey2019.SealTo(machineKey, jsonBody), nil
}
return jsonBody, nil
}
func (m Mapper) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var respBody []byte
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
if !m.isNoise { // if legacy protocol
respBody = m.privateKey2019.SealTo(machineKey, respBody)
}
} else {
if !m.isNoise { // if legacy protocol
respBody = m.privateKey2019.SealTo(machineKey, jsonBody)
} else {
respBody = jsonBody
}
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}

View file

@ -0,0 +1,131 @@
package mapper
import (
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
mach := func(hostname, username string, userid uint) types.Machine {
return types.Machine{
Hostname: hostname,
UserID: userid,
User: types.User{
Name: username,
},
}
}
machineInShared1 := mach("test_get_shared_nodes_1", "user1", 1)
machineInShared2 := mach("test_get_shared_nodes_2", "user2", 2)
machineInShared3 := mach("test_get_shared_nodes_3", "user3", 3)
machine2InShared1 := mach("test_get_shared_nodes_4", "user1", 1)
userProfiles := generateUserProfiles(
&machineInShared1,
types.Machines{
machineInShared2, machineInShared3, machine2InShared1,
},
"",
)
c.Assert(len(userProfiles), check.Equals, 3)
users := []string{
"user1", "user2", "user3",
}
for _, user := range users {
found := false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == user {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
}
func TestDNSConfigMapResponse(t *testing.T) {
tests := []struct {
magicDNS bool
want *tailcfg.DNSConfig
}{
{
magicDNS: true,
want: &tailcfg.DNSConfig{
Routes: map[string][]*dnstype.Resolver{
"shared1.foobar.headscale.net": {},
"shared2.foobar.headscale.net": {},
"shared3.foobar.headscale.net": {},
},
Domains: []string{
"foobar.headscale.net",
"shared1.foobar.headscale.net",
},
Proxied: true,
},
},
{
magicDNS: false,
want: &tailcfg.DNSConfig{
Domains: []string{"foobar.headscale.net"},
Proxied: false,
},
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) {
mach := func(hostname, username string, userid uint) types.Machine {
return types.Machine{
Hostname: hostname,
UserID: userid,
User: types.User{
Name: username,
},
}
}
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: tt.magicDNS,
}
machineInShared1 := mach("test_get_shared_nodes_1", "shared1", 1)
machineInShared2 := mach("test_get_shared_nodes_2", "shared2", 2)
machineInShared3 := mach("test_get_shared_nodes_3", "shared3", 3)
machine2InShared1 := mach("test_get_shared_nodes_4", "shared1", 1)
peersOfMachineInShared1 := types.Machines{
machineInShared1,
machineInShared2,
machineInShared3,
machine2InShared1,
}
got := generateDNSConfig(
&dnsConfigOrig,
baseDomain,
machineInShared1,
peersOfMachineInShared1,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,15 @@
package mapper
import (
"testing"
"gopkg.in/check.v1"
)
func Test(t *testing.T) {
check.TestingT(t)
}
var _ = check.Suite(&Suite{})
type Suite struct{}

View file

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -61,7 +62,7 @@ func (h *Headscale) KeyHandler(
// TS2021 (Tailscale v2 protocol) requires to have a different key // TS2021 (Tailscale v2 protocol) requires to have a different key
if clientCapabilityVersion >= NoiseCapabilityVersion { if clientCapabilityVersion >= NoiseCapabilityVersion {
resp := tailcfg.OverTLSPublicKeyResponse{ resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(), LegacyPublicKey: h.privateKey2019.Public(),
PublicKey: h.noisePrivateKey.Public(), PublicKey: h.noisePrivateKey.Public(),
} }
writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Type", "application/json")
@ -84,7 +85,7 @@ func (h *Headscale) KeyHandler(
// Old clients don't send a 'v' parameter, so we send the legacy public key // Old clients don't send a 'v' parameter, so we send the legacy public key
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public()))) _, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -323,7 +324,7 @@ func (h *Headscale) handleAuthKeyCommon(
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -483,7 +484,7 @@ func (h *Headscale) handleAuthKeyCommon(
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
resp.Login = *pak.User.TailscaleLogin() resp.Login = *pak.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -548,7 +549,7 @@ func (h *Headscale) handleNewMachineCommon(
registerRequest.NodeKey) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -609,7 +610,7 @@ func (h *Headscale) handleMachineLogOutCommon(
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.NodeKeyExpired = true resp.NodeKeyExpired = true
resp.User = *machine.User.TailscaleUser() resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -673,7 +674,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
resp.User = *machine.User.TailscaleUser() resp.User = *machine.User.TailscaleUser()
resp.Login = *machine.User.TailscaleLogin() resp.Login = *machine.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -735,7 +736,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *machine.User.TailscaleUser() resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -802,7 +803,7 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
registerRequest.NodeKey) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -29,6 +30,19 @@ func (h *Headscale) handlePollCommon(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
isNoise bool, isNoise bool,
) { ) {
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
// per client or something similar
mapp := mapper.NewMapper(h.db,
h.privateKey2019,
isNoise,
h.DERPMap,
h.cfg.BaseDomain,
h.cfg.DNSConfig,
h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
)
machine.Hostname = mapRequest.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
@ -87,7 +101,7 @@ func (h *Headscale) handlePollCommon(
return return
} }
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -245,6 +259,19 @@ func (h *Headscale) pollNetMapStream(
updateChan chan struct{}, updateChan chan struct{},
isNoise bool, isNoise bool,
) { ) {
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
// per client or something similar
mapp := mapper.NewMapper(h.db,
h.privateKey2019,
isNoise,
h.DERPMap,
h.cfg.BaseDomain,
h.cfg.DNSConfig,
h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
)
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -463,7 +490,7 @@ func (h *Headscale) pollNetMapStream(
Time("last_successful_update", lastUpdate). Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)). Time("last_state_change", h.getLastStateChange(machine.User)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname) Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := h.getMapResponseData(mapRequest, machine, isNoise) data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -623,6 +650,19 @@ func (h *Headscale) scheduledPollWorker(
machine *types.Machine, machine *types.Machine,
isNoise bool, isNoise bool,
) { ) {
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
// per client or something similar
mapp := mapper.NewMapper(h.db,
h.privateKey2019,
isNoise,
h.DERPMap,
h.cfg.BaseDomain,
h.cfg.DNSConfig,
h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
)
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval) updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
@ -643,7 +683,7 @@ func (h *Headscale) scheduledPollWorker(
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise) data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "keepAlive"). Str("func", "keepAlive").

View file

@ -1,157 +1 @@
package hscontrol package hscontrol
import (
"encoding/binary"
"encoding/json"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (h *Headscale) getMapResponseData(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
isNoise bool,
) ([]byte, error) {
mapResponse, err := h.generateMapResponse(mapRequest, machine)
if err != nil {
return nil, err
}
if isNoise {
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise)
}
func (h *Headscale) getMapKeepAliveResponseData(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
isNoise bool,
) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{
KeepAlive: true,
}
if isNoise {
return h.marshalMapResponse(
keepAliveResponse,
key.MachinePublic{},
mapRequest.Compress,
isNoise,
)
}
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise)
}
func (h *Headscale) marshalResponse(
resp interface{},
machineKey key.MachinePublic,
isNoise bool,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if isNoise {
return jsonBody, nil
}
return h.privateKey.SealTo(machineKey, jsonBody), nil
}
func (h *Headscale) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
isNoise bool,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var respBody []byte
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, respBody)
}
} else {
if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, jsonBody)
} else {
respBody = jsonBody
}
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}

View file

@ -45,7 +45,7 @@ func (h *Headscale) RegistrationHandler(
return return
} }
registerRequest := tailcfg.RegisterRequest{} registerRequest := tailcfg.RegisterRequest{}
err = util.DecodeAndUnmarshalNaCl(body, &registerRequest, &machineKey, h.privateKey) err = util.DecodeAndUnmarshalNaCl(body, &registerRequest, &machineKey, h.privateKey2019)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -57,7 +57,7 @@ func (h *Headscale) PollNetMapHandler(
return return
} }
mapRequest := tailcfg.MapRequest{} mapRequest := tailcfg.MapRequest{}
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey) err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey2019)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").