mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-19 10:20:05 +09:00
Register new machines needing callback in memory
This commit stores temporary registration data in cache, instead of memory allowing us to only have actually registered machines in the database.
This commit is contained in:
parent
1caa6f5d69
commit
469551bc5d
7 changed files with 136 additions and 95 deletions
153
api.go
153
api.go
|
@ -125,25 +125,40 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
}
|
||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
|
||||
|
||||
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
||||
if req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(ctx, machineKey, req)
|
||||
|
||||
return
|
||||
}
|
||||
machine = &newMachine
|
||||
|
||||
// The machine did not have a key to authenticate, which means
|
||||
// that we rely on a method that calls back some how (OpenID or CLI)
|
||||
// We create the machine and then keep it around until a callback
|
||||
// happens
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: machineKeyStr,
|
||||
Name: req.Hostinfo.Hostname,
|
||||
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
|
||||
LastSeen: &now,
|
||||
}
|
||||
|
||||
h.registrationCache.Set(
|
||||
machineKeyStr,
|
||||
newMachine,
|
||||
requestedExpiryCacheExpiration,
|
||||
)
|
||||
|
||||
h.handleMachineRegistrationNew(ctx, machineKey, req)
|
||||
return
|
||||
}
|
||||
|
||||
if machine.Registered {
|
||||
// The machine is already registered, so we need to pass through reauth or key update.
|
||||
if machine != nil {
|
||||
// If the NodeKey stored in headscale is the same as the key presented in a registration
|
||||
// request, then we have a node that is either:
|
||||
// - Trying to log out (sending a expiry in the past)
|
||||
|
@ -180,15 +195,6 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
||||
if req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(ctx, machineKey, req, *machine)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.handleMachineRegistrationNew(ctx, machineKey, req, *machine)
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapResponse(
|
||||
|
@ -402,7 +408,7 @@ func (h *Headscale) handleMachineExpired(
|
|||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if registerRequest.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(ctx, machineKey, registerRequest, machine)
|
||||
h.handleAuthKey(ctx, machineKey, registerRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -465,13 +471,12 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
ctx *gin.Context,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
// The machine registration is new, redirect the client to the registration URL
|
||||
log.Debug().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf(
|
||||
|
@ -487,7 +492,7 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
if !registerRequest.Expiry.IsZero() {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Time("expiry", registerRequest.Expiry).
|
||||
Msg("Non-zero expiry time requested, adding to cache")
|
||||
h.requestedExpiryCache.Set(
|
||||
|
@ -497,11 +502,6 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
)
|
||||
}
|
||||
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
|
||||
// save the NodeKey
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -520,19 +520,21 @@ func (h *Headscale) handleAuthKey(
|
|||
ctx *gin.Context,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
|
||||
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
resp.MachineAuthorized = false
|
||||
|
@ -541,69 +543,62 @@ func (h *Headscale) handleAuthKey(
|
|||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if machine.isRegistered() {
|
||||
log.Trace().
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||
|
||||
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
now := time.Now().UTC()
|
||||
|
||||
machine, err := h.RegisterMachine(
|
||||
registerRequest.Hostinfo.Hostname,
|
||||
machineKeyStr,
|
||||
pak.Namespace.Name,
|
||||
RegisterMethodAuthKey,
|
||||
®isterRequest.Expiry,
|
||||
pak,
|
||||
&nodeKey,
|
||||
&now,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Msg("machine already registered, reauthenticating")
|
||||
|
||||
h.RefreshMachine(&machine, registerRequest.Expiry)
|
||||
} else {
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||
|
||||
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
now := time.Now().UTC()
|
||||
|
||||
_, err = h.RegisterMachine(
|
||||
machine.Name,
|
||||
machine.Namespace.Name,
|
||||
RegisterMethodAuthKey,
|
||||
®isterRequest.Expiry,
|
||||
pak,
|
||||
&nodeKey,
|
||||
&now,
|
||||
Err(err).
|
||||
Msg("could not register machine")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
|
||||
Inc()
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not register machine",
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("could not register machine")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).Inc()
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not register machine",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pak.Used = true
|
||||
h.db.Save(&pak)
|
||||
h.UsePreAuthKey(pak)
|
||||
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *pak.Namespace.toUser()
|
||||
|
@ -612,21 +607,21 @@ func (h *Headscale) handleAuthKey(
|
|||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
|
||||
Inc()
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", machine.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
|
||||
Inc()
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
|
||||
Msg("Successfully authenticated via AuthKey")
|
||||
}
|
||||
|
|
9
app.go
9
app.go
|
@ -154,6 +154,8 @@ type Headscale struct {
|
|||
|
||||
requestedExpiryCache *cache.Cache
|
||||
|
||||
registrationCache *cache.Cache
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -207,6 +209,12 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
requestedExpiryCacheCleanupInterval,
|
||||
)
|
||||
|
||||
registrationCache := cache.New(
|
||||
// TODO(kradalby): Add unified cache expiry config options
|
||||
requestedExpiryCacheExpiration,
|
||||
requestedExpiryCacheCleanupInterval,
|
||||
)
|
||||
|
||||
app := Headscale{
|
||||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
|
@ -214,6 +222,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
privateKey: privKey,
|
||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||
requestedExpiryCache: requestedExpiryCache,
|
||||
registrationCache: registrationCache,
|
||||
}
|
||||
|
||||
err = app.initDB()
|
||||
|
|
|
@ -30,6 +30,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineAfterRegistering, err := app.RegisterMachine(
|
||||
"testmachine",
|
||||
machine.MachineKey,
|
||||
namespace.Name,
|
||||
RegisterMethodCLI,
|
||||
|
|
|
@ -174,12 +174,11 @@ func (api headscaleV1APIServer) RegisterMachine(
|
|||
}
|
||||
}
|
||||
|
||||
machine, err := api.h.RegisterMachine(
|
||||
machine, err := api.h.RegisterMachineFromAuthCallback(
|
||||
request.GetKey(),
|
||||
request.GetNamespace(),
|
||||
RegisterMethodCLI,
|
||||
&requestedTime,
|
||||
nil, nil, nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
52
machine.go
52
machine.go
|
@ -20,11 +20,15 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
errMachineNotFound = Error("machine not found")
|
||||
errMachineAlreadyRegistered = Error("machine already registered")
|
||||
errMachineRouteIsNotAvailable = Error("route is not available on machine")
|
||||
errMachineAddressesInvalid = Error("failed to parse machine addresses")
|
||||
errHostnameTooLong = Error("Hostname too long")
|
||||
errMachineNotFound = Error("machine not found")
|
||||
errMachineAlreadyRegistered = Error("machine already registered")
|
||||
errMachineRouteIsNotAvailable = Error("route is not available on machine")
|
||||
errMachineAddressesInvalid = Error("failed to parse machine addresses")
|
||||
errMachineNotFoundRegistrationCache = Error(
|
||||
"machine not found in registration cache",
|
||||
)
|
||||
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
|
||||
errHostnameTooLong = Error("Hostname too long")
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -686,14 +690,44 @@ func (machine *Machine) toProto() *v1.Machine {
|
|||
return machineProto
|
||||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (h *Headscale) RegisterMachine(
|
||||
func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||
machineKeyStr string,
|
||||
namespaceName string,
|
||||
registrationMethod string,
|
||||
expiry *time.Time,
|
||||
) (*Machine, error) {
|
||||
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
|
||||
if registrationMachine, ok := machineInterface.(Machine); ok {
|
||||
machine, err := h.RegisterMachine(
|
||||
registrationMachine.Name,
|
||||
machineKeyStr,
|
||||
namespaceName,
|
||||
registrationMethod,
|
||||
expiry,
|
||||
nil,
|
||||
®istrationMachine.NodeKey,
|
||||
registrationMachine.LastSeen,
|
||||
)
|
||||
|
||||
return machine, err
|
||||
|
||||
} else {
|
||||
return nil, errCouldNotConvertMachineInterface
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMachineNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (h *Headscale) RegisterMachine(
|
||||
machineName string,
|
||||
machineKeyStr string,
|
||||
namespaceName string,
|
||||
registrationMethod string,
|
||||
expiry *time.Time,
|
||||
|
||||
// Optionals
|
||||
expiry *time.Time,
|
||||
authKey *PreAuthKey,
|
||||
nodePublicKey *string,
|
||||
lastSeen *time.Time,
|
||||
|
@ -768,6 +802,7 @@ func (h *Headscale) RegisterMachine(
|
|||
machine.LastSeen = lastSeen
|
||||
}
|
||||
|
||||
machine.Name = machineName
|
||||
machine.NamespaceID = namespace.ID
|
||||
|
||||
// TODO(kradalby): This field is uneccessary metadata,
|
||||
|
@ -780,6 +815,7 @@ func (h *Headscale) RegisterMachine(
|
|||
// Let us simplify the model, a machine is _only_ saved if
|
||||
// it is registered.
|
||||
machine.Registered = true
|
||||
|
||||
h.db.Save(&machine)
|
||||
|
||||
log.Trace().
|
||||
|
|
7
oidc.go
7
oidc.go
|
@ -279,8 +279,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
namespaceName, err := NormalizeNamespaceName(
|
||||
claims.Email,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
|
@ -328,14 +326,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err = h.RegisterMachine(
|
||||
_, err = h.RegisterMachineFromAuthCallback(
|
||||
machineKeyStr,
|
||||
namespace.Name,
|
||||
RegisterMethodOIDC,
|
||||
&requestedTime,
|
||||
nil,
|
||||
nil,
|
||||
&now,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
|
|
@ -113,6 +113,12 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func (h *Headscale) UsePreAuthKey(k *PreAuthKey) {
|
||||
k.Used = true
|
||||
h.db.Save(k)
|
||||
}
|
||||
|
||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||
|
|
Loading…
Reference in a new issue