Run the Noise handlers under a new struct so we can access the noiseConn from the handlers

In TS2021 the MachineKey can be obtained from noiseConn.Peer() - contrary to what I thought before,
where I assumed MachineKey was dropped in TS2021.

By having a ts2021App and hanging from there the TS2021 handlers, we can fetch again the MachineKey.
This commit is contained in:
Juan Font 2022-12-09 16:56:43 +00:00
parent 6e890afc5f
commit 593040b73d
11 changed files with 210 additions and 118 deletions

View file

@ -7,6 +7,7 @@
- Reworked routing and added support for subnet router failover [#1024](https://github.com/juanfont/headscale/pull/1024) - Reworked routing and added support for subnet router failover [#1024](https://github.com/juanfont/headscale/pull/1024)
- Added an OIDC AllowGroups Configuration options and authorization check [#1041](https://github.com/juanfont/headscale/pull/1041) - Added an OIDC AllowGroups Configuration options and authorization check [#1041](https://github.com/juanfont/headscale/pull/1041)
- Set `db_ssl` to false by default [#1052](https://github.com/juanfont/headscale/pull/1052) - Set `db_ssl` to false by default [#1052](https://github.com/juanfont/headscale/pull/1052)
- Fix duplicate nodes due to incorrect implementation of the protocol [#1058](https://github.com/juanfont/headscale/pull/1058)
- Report if a machine is online in CLI more accurately [#1062](https://github.com/juanfont/headscale/pull/1062) - Report if a machine is online in CLI more accurately [#1062](https://github.com/juanfont/headscale/pull/1062)
## 0.17.1 (2022-12-05) ## 0.17.1 (2022-12-05)

18
app.go
View file

@ -81,8 +81,6 @@ type Headscale struct {
privateKey *key.MachinePrivate privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate noisePrivateKey *key.MachinePrivate
noiseMux *mux.Router
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *DERPServer DERPServer *DERPServer
@ -472,16 +470,6 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
return router return router
} }
func (h *Headscale) createNoiseMux() *mux.Router {
router := mux.NewRouter()
router.HandleFunc("/machine/register", h.NoiseRegistrationHandler).
Methods(http.MethodPost)
router.HandleFunc("/machine/map", h.NoisePollNetMapHandler)
return router
}
// Serve launches a GIN server with the Headscale API. // Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
var err error var err error
@ -641,12 +629,6 @@ func (h *Headscale) Serve() error {
// over our main Addr. It also serves the legacy Tailcale API // over our main Addr. It also serves the legacy Tailcale API
router := h.createRouter(grpcGatewayMux) router := h.createRouter(grpcGatewayMux)
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
h.noiseMux = h.createNoiseMux()
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: router, Handler: router,

View file

@ -469,6 +469,7 @@ func nodesToPtables(
"ID", "ID",
"Hostname", "Hostname",
"Name", "Name",
"MachineKey",
"NodeKey", "NodeKey",
"Namespace", "Namespace",
"IP addresses", "IP addresses",
@ -504,8 +505,16 @@ func nodesToPtables(
expiry = machine.Expiry.AsTime() expiry = machine.Expiry.AsTime()
} }
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(headscale.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
machineKey = key.MachinePublic{}
}
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText( err = nodeKey.UnmarshalText(
[]byte(headscale.NodePublicKeyEnsurePrefix(machine.NodeKey)), []byte(headscale.NodePublicKeyEnsurePrefix(machine.NodeKey)),
) )
if err != nil { if err != nil {
@ -568,6 +577,7 @@ func nodesToPtables(
strconv.FormatUint(machine.Id, headscale.Base10), strconv.FormatUint(machine.Id, headscale.Base10),
machine.Name, machine.Name,
machine.GetGivenName(), machine.GetGivenName(),
machineKey.ShortString(),
nodeKey.ShortString(), nodeKey.ShortString(),
namespace, namespace,
strings.Join([]string{IPV4Address, IPV6Address}, ", "), strings.Join([]string{IPV4Address, IPV6Address}, ", "),

View file

@ -418,13 +418,15 @@ func (h *Headscale) GetMachineByNodeKey(
return &machine, nil return &machine, nil
} }
// GetMachineByAnyNodeKey finds a Machine by its current NodeKey or the old one, and returns the Machine struct. // GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByAnyNodeKey( func (h *Headscale) GetMachineByAnyKey(
nodeKey key.NodePublic, oldNodeKey key.NodePublic, machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) { ) (*Machine, error) {
machine := Machine{} machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?", if result := h.db.Preload("Namespace").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { MachinePublicKeyStripPrefix(machineKey),
NodePublicKeyStripPrefix(nodeKey),
NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -850,6 +852,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
return nil, err return nil, err
} }
log.Debug().
Str("nodeKey", nodeKey.ShortString()).
Str("namespaceName", namespaceName).
Str("registrationMethod", registrationMethod).
Msg("Registering machine from API/CLI or auth callback")
if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok { if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
@ -889,15 +897,31 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(machine Machine, func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) { ) (*Machine, error) {
log.Trace(). log.Debug().
Caller(). Str("machine", machine.Hostname).
Str("machine_key", machine.MachineKey). Str("machine_key", machine.MachineKey).
Str("node_key", machine.NodeKey).
Str("namespace", machine.Namespace.Name).
Msg("Registering machine") Msg("Registering machine")
// If the machine exists and we had already IPs for it, we just save it
// so we store the machine.Expire and machine.Nodekey that has been set when
// adding it to the registrationCache
if len(machine.IPAddresses) > 0 {
if err := h.db.Save(&machine).Error; err != nil {
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
}
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Attempting to register machine") Str("machine_key", machine.MachineKey).
Str("node_key", machine.NodeKey).
Str("namespace", machine.Namespace.Name).
Msg("Machine authorized again")
return &machine, nil
}
h.ipAllocationMutex.Lock() h.ipAllocationMutex.Lock()
defer h.ipAllocationMutex.Unlock() defer h.ipAllocationMutex.Unlock()

View file

@ -77,10 +77,11 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine()
machine := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
@ -107,9 +108,11 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
oldNodeKey := key.NewNode() oldNodeKey := key.NewNode()
machineKey := key.NewMachine()
machine := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
@ -119,7 +122,7 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
} }
app.db.Save(&machine) app.db.Save(&machine)
_, err = app.GetMachineByAnyNodeKey(nodeKey.Public(), oldNodeKey.Public()) _, err = app.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }

View file

@ -3,9 +3,11 @@ package headscale
import ( import (
"net/http" "net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
) )
@ -15,6 +17,12 @@ const (
ts2021UpgradePath = "/ts2021" ts2021UpgradePath = "/ts2021"
) )
type ts2021App struct {
headscale *Headscale
conn *controlbase.Conn
}
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn // NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021. // in order to use the Noise-based TS2021 protocol. Listens in /ts2021.
func (h *Headscale) NoiseUpgradeHandler( func (h *Headscale) NoiseUpgradeHandler(
@ -44,10 +52,25 @@ func (h *Headscale) NoiseUpgradeHandler(
return return
} }
ts2021App := ts2021App{
headscale: h,
conn: noiseConn,
}
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter()
router.HandleFunc("/machine/register", ts2021App.NoiseRegistrationHandler).
Methods(http.MethodPost)
router.HandleFunc("/machine/map", ts2021App.NoisePollNetMapHandler)
server := http.Server{ server := http.Server{
ReadTimeout: HTTPReadTimeout, ReadTimeout: HTTPReadTimeout,
} }
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{}) server.Handler = h2c.NewHandler(router, &http2.Server{})
err = server.Serve(netutil.NewOneConnListener(noiseConn, nil)) err = server.Serve(netutil.NewOneConnListener(noiseConn, nil))
if err != nil { if err != nil {
log.Info().Err(err).Msg("The HTTP2 server was closed") log.Info().Err(err).Msg("The HTTP2 server was closed")

View file

@ -99,13 +99,14 @@ func (h *Headscale) handleRegisterCommon(
req *http.Request, req *http.Request,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
now := time.Now().UTC() now := time.Now().UTC()
machine, err := h.GetMachineByAnyNodeKey(registerRequest.NodeKey, registerRequest.OldNodeKey) machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys // If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey) h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
@ -123,10 +124,11 @@ func (h *Headscale) handleRegisterCommon(
log.Debug(). log.Debug().
Caller(). Caller().
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup). Str("follow_up", registerRequest.Followup).
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Msg("Machine is waiting for interactive login") Msg("Machine is waiting for interactive login")
ticker := time.NewTicker(registrationHoldoff) ticker := time.NewTicker(registrationHoldoff)
@ -134,7 +136,7 @@ func (h *Headscale) handleRegisterCommon(
case <-req.Context().Done(): case <-req.Context().Done():
return return
case <-ticker.C: case <-ticker.C:
h.handleNewMachineCommon(writer, registerRequest, machineKey) h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
@ -144,10 +146,11 @@ func (h *Headscale) handleRegisterCommon(
log.Info(). log.Info().
Caller(). Caller().
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup). Str("follow_up", registerRequest.Followup).
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Msg("New machine not yet in the database") Msg("New machine not yet in the database")
givenName, err := h.GenerateGivenName( givenName, err := h.GenerateGivenName(
@ -180,7 +183,7 @@ func (h *Headscale) handleRegisterCommon(
if !registerRequest.Expiry.IsZero() { if !registerRequest.Expiry.IsZero() {
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry). Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested") Msg("Non-zero expiry time requested")
@ -193,32 +196,56 @@ func (h *Headscale) handleRegisterCommon(
registerCacheExpiration, registerCacheExpiration,
) )
h.handleNewMachineCommon(writer, registerRequest, machineKey) h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
// The machine is already registered, so we need to pass through reauth or key update. // The machine is already in the DB. This could mean one of the following:
// - The machine is authenticated and ready to /map
// - We are doing a key refresh
// - The machine is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here
if machine != nil { if machine != nil {
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey)
if err := h.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("machine", machine.Hostname).
Err(err).
Msg("Error saving machine key to database")
return
}
}
// If the NodeKey stored in headscale is the same as the key presented in a registration // If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either: // request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map // - A valid, registered machine, looking for /map
// - Expired machine wanting to reauthenticate // - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) { registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, *machine, machineKey) h.handleMachineLogOutCommon(writer, *machine, machineKey, isNoise)
return return
} }
// If machine is not expired, and is register, we have a already accepted this machine, // If machine is not expired, and it is register, we have a already accepted this machine,
// let it proceed with a valid registration // let it proceed with a valid registration
if !machine.isExpired() { if !machine.isExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey) h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
return return
} }
@ -232,15 +259,23 @@ func (h *Headscale) handleRegisterCommon(
registerRequest, registerRequest,
*machine, *machine,
machineKey, machineKey,
isNoise,
) )
return return
} }
// The machine has expired // The machine has expired or it is logged out
h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey) h.handleMachineExpiredOrLoggedOutCommon(writer, registerRequest, *machine, machineKey, isNoise)
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
machine.Expiry = &time.Time{} machine.Expiry = &time.Time{}
// If we are here it means the client needs to be reauthorized,
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.registrationCache.Set( h.registrationCache.Set(
NodePublicKeyStripPrefix(registerRequest.NodeKey), NodePublicKeyStripPrefix(registerRequest.NodeKey),
*machine, *machine,
@ -260,11 +295,12 @@ func (h *Headscale) handleAuthKeyCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
log.Debug(). log.Debug().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
@ -273,18 +309,18 @@ func (h *Headscale) handleAuthKeyCommon(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
@ -301,7 +337,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
} }
@ -309,7 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
@ -325,7 +361,7 @@ func (h *Headscale) handleAuthKeyCommon(
log.Debug(). log.Debug().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
@ -335,11 +371,11 @@ func (h *Headscale) handleAuthKeyCommon(
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByAnyNodeKey(registerRequest.NodeKey, registerRequest.OldNodeKey) machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("machine was already registered before, refreshing with new auth key") Msg("machine was already registered before, refreshing with new auth key")
@ -349,7 +385,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to refresh machine") Msg("Failed to refresh machine")
@ -365,7 +401,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Strs("aclTags", aclTags). Strs("aclTags", aclTags).
Err(err). Err(err).
@ -381,7 +417,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("func", "RegistrationHandler"). Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname). Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err) Err(err)
@ -408,7 +444,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("could not register machine") Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
@ -423,7 +459,7 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to use pre-auth key") Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
@ -439,11 +475,11 @@ func (h *Headscale) handleAuthKeyCommon(
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
resp.Login = *pak.Namespace.toLogin() resp.Login = *pak.Namespace.toLogin()
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
@ -462,14 +498,14 @@ func (h *Headscale) handleAuthKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
} }
log.Info(). log.Info().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
@ -481,13 +517,14 @@ func (h *Headscale) handleNewMachineCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL // The machine registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url") Msg("The node seems to be new, sending auth url")
@ -503,11 +540,11 @@ func (h *Headscale) handleNewMachineCommon(
registerRequest.NodeKey) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -520,7 +557,7 @@ func (h *Headscale) handleNewMachineCommon(
_, err = writer.Write(respBody) _, err = writer.Write(respBody)
if err != nil { if err != nil {
log.Error(). log.Error().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
@ -528,7 +565,7 @@ func (h *Headscale) handleNewMachineCommon(
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("AuthURL", resp.AuthURL). Str("AuthURL", resp.AuthURL).
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url") Msg("Successfully sent auth url")
@ -538,11 +575,12 @@ func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
log.Info(). log.Info().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client requested logout") Msg("Client requested logout")
@ -550,7 +588,7 @@ func (h *Headscale) handleMachineLogOutCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("func", "handleMachineLogOutCommon"). Str("func", "handleMachineLogOutCommon").
Err(err). Err(err).
Msg("Failed to expire machine") Msg("Failed to expire machine")
@ -561,12 +599,13 @@ func (h *Headscale) handleMachineLogOutCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.NodeKeyExpired = true
resp.User = *machine.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -579,7 +618,7 @@ func (h *Headscale) handleMachineLogOutCommon(
_, err = writer.Write(respBody) _, err = writer.Write(respBody)
if err != nil { if err != nil {
log.Error(). log.Error().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
@ -587,7 +626,7 @@ func (h *Headscale) handleMachineLogOutCommon(
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Successfully logged out") Msg("Successfully logged out")
} }
@ -596,13 +635,14 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map // The machine registration is valid, respond with redirect to /map
log.Debug(). log.Debug().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map") Msg("Client is registered and we have the current NodeKey. All clear to /map")
@ -611,11 +651,11 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
resp.User = *machine.Namespace.toUser() resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin() resp.Login = *machine.Namespace.toLogin()
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
@ -633,14 +673,14 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
} }
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Machine successfully authorized") Msg("Machine successfully authorized")
} }
@ -650,12 +690,13 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
log.Debug(). log.Info().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
@ -672,11 +713,11 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *machine.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -690,41 +731,45 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
} }
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("old_node_key", registerRequest.OldNodeKey.ShortString()). Str("old_node_key", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Machine successfully refreshed") Msg("Node key successfully refreshed")
} }
func (h *Headscale) handleMachineExpiredCommon( func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey) h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
// The client has registered before, but has expired or logged out
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Msg("Machine registration has expired or logged out. Sending a auth url to register")
if h.oauth2Config != nil { if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
@ -735,11 +780,11 @@ func (h *Headscale) handleMachineExpiredCommon(
registerRequest.NodeKey) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
@ -757,14 +802,17 @@ func (h *Headscale) handleMachineExpiredCommon(
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
} }
log.Info(). log.Trace().
Caller(). Caller().
Bool("noise", machineKey.IsZero()). Bool("noise", isNoise).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Auth URL for reauthenticate successfully sent") Msg("Machine logged out. Sent AuthURL for reauthentication")
} }

View file

@ -21,7 +21,7 @@ func (h *Headscale) getMapResponseData(
} }
if isNoise { if isNoise {
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
@ -35,7 +35,7 @@ func (h *Headscale) getMapResponseData(
return nil, err return nil, err
} }
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise)
} }
func (h *Headscale) getMapKeepAliveResponseData( func (h *Headscale) getMapKeepAliveResponseData(
@ -48,7 +48,7 @@ func (h *Headscale) getMapKeepAliveResponseData(
} }
if isNoise { if isNoise {
return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress) return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
@ -62,12 +62,13 @@ func (h *Headscale) getMapKeepAliveResponseData(
return nil, err return nil, err
} }
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise)
} }
func (h *Headscale) marshalResponse( func (h *Headscale) marshalResponse(
resp interface{}, resp interface{},
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
jsonBody, err := json.Marshal(resp) jsonBody, err := json.Marshal(resp)
if err != nil { if err != nil {
@ -79,7 +80,7 @@ func (h *Headscale) marshalResponse(
return nil, err return nil, err
} }
if machineKey.IsZero() { // if Noise if isNoise {
return jsonBody, nil return jsonBody, nil
} }
@ -90,6 +91,7 @@ func (h *Headscale) marshalMapResponse(
resp interface{}, resp interface{},
machineKey key.MachinePublic, machineKey key.MachinePublic,
compression string, compression string,
isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
jsonBody, err := json.Marshal(resp) jsonBody, err := json.Marshal(resp)
if err != nil { if err != nil {
@ -103,11 +105,11 @@ func (h *Headscale) marshalMapResponse(
if compression == ZstdCompression { if compression == ZstdCompression {
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
respBody = encoder.EncodeAll(jsonBody, nil) respBody = encoder.EncodeAll(jsonBody, nil)
if !machineKey.IsZero() { // if legacy protocol if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, respBody) respBody = h.privateKey.SealTo(machineKey, respBody)
} }
} else { } else {
if !machineKey.IsZero() { // if legacy protocol if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, jsonBody) respBody = h.privateKey.SealTo(machineKey, jsonBody)
} else { } else {
respBody = jsonBody respBody = jsonBody

View file

@ -56,5 +56,5 @@ func (h *Headscale) RegistrationHandler(
return return
} }
h.handleRegisterCommon(writer, req, registerRequest, machineKey) h.handleRegisterCommon(writer, req, registerRequest, machineKey, false)
} }

View file

@ -7,11 +7,10 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
// // NoiseRegistrationHandler handles the actual registration process of a machine. // // NoiseRegistrationHandler handles the actual registration process of a machine.
func (h *Headscale) NoiseRegistrationHandler( func (t *ts2021App) NoiseRegistrationHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
@ -34,5 +33,5 @@ func (h *Headscale) NoiseRegistrationHandler(
return return
} }
h.handleRegisterCommon(writer, req, registerRequest, key.MachinePublic{}) t.headscale.handleRegisterCommon(writer, req, registerRequest, t.conn.Peer(), true)
} }

View file

@ -21,7 +21,7 @@ import (
// only after their first request (marked with the ReadOnly field). // only after their first request (marked with the ReadOnly field).
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler( func (t *ts2021App) NoisePollNetMapHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
@ -41,7 +41,7 @@ func (h *Headscale) NoisePollNetMapHandler(
return return
} }
machine, err := h.GetMachineByAnyNodeKey(mapRequest.NodeKey, key.NodePublic{}) machine, err := t.headscale.GetMachineByAnyKey(t.conn.Peer(), mapRequest.NodeKey, key.NodePublic{})
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
@ -63,5 +63,5 @@ func (h *Headscale) NoisePollNetMapHandler(
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("A machine is entering polling via the Noise protocol") Msg("A machine is entering polling via the Noise protocol")
h.handlePollCommon(writer, req.Context(), machine, mapRequest, true) t.headscale.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
} }