use registration ID in reg flow, support followup

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-01-13 11:54:35 +01:00
parent fbd0d5e7ce
commit b4e7394563
No known key found for this signature in database
8 changed files with 143 additions and 119 deletions

View file

@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper mapper *mapper.Mapper
nodeNotifier *notifier.Notifier nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[string, types.Node] registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
authProvider AuthProvider authProvider AuthProvider
@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
} }
registrationCache := zcache.New[string, types.Node]( registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration, registerCacheExpiration,
registerCacheCleanup, registerCacheCleanup,
) )
@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet) router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)

View file

@ -6,6 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
@ -20,16 +22,18 @@ import (
type AuthProvider interface { type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request) RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string AuthURL(types.RegistrationID) string
} }
func logAuthFunc( func logAuthFunc(
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) { ) (func(string), func(string), func(error, string)) {
return func(msg string) { return func(msg string) {
log.Info(). log.Info().
Caller(). Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -41,6 +45,7 @@ func logAuthFunc(
func(msg string) { func(msg string) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -52,6 +57,7 @@ func logAuthFunc(
func(err error, msg string) { func(err error, msg string) {
log.Error(). log.Error().
Caller(). Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -63,6 +69,40 @@ func logAuthFunc(
} }
} }
func (h *Headscale) waitForFollowup(
req *http.Request,
regReq tailcfg.RegisterRequest,
logTrace func(string),
) {
logTrace("register request is a followup")
fu, err := url.Parse(regReq.Followup)
if err != nil {
logTrace("failed to parse followup URL")
return
}
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil {
logTrace("followup URL does not contains a valid registration ID")
return
}
logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))
if reg, ok := h.registrationCache.Get(followupReg); ok {
logTrace("Node is waiting for interactive login")
select {
case <-req.Context().Done():
logTrace("node went away before it was registered")
return
case <-reg.Registered:
logTrace("node has successfully registered")
return
}
}
}
// handleRegister is the logic for registering a client. // handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister( func (h *Headscale) handleRegister(
writer http.ResponseWriter, writer http.ResponseWriter,
@ -70,7 +110,18 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey) registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC() now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB") logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
@ -84,27 +135,9 @@ func (h *Headscale) handleRegister(
} }
// Check if the node is waiting for interactive login. // Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if regReq.Followup != "" { if regReq.Followup != "" {
logTrace("register request is a followup") h.waitForFollowup(req, regReq, logTrace)
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
logTrace("Node is waiting for interactive login")
select {
case <-req.Context().Done():
return return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)
return
}
}
} }
logInfo("Node not found in database, creating new") logInfo("Node not found in database, creating new")
@ -113,25 +146,28 @@ func (h *Headscale) handleRegister(
// that we rely on a method that calls back some how (OpenID or CLI) // that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback // We create the node and then keep it around until a callback
// happens // happens
newNode := types.Node{ newNode := types.RegisterNode{
Node: types.Node{
MachineKey: machineKey, MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname, Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey, NodeKey: regReq.NodeKey,
LastSeen: &now, LastSeen: &now,
Expiry: &time.Time{}, Expiry: &time.Time{},
},
Registered: make(chan struct{}),
} }
if !regReq.Expiry.IsZero() { if !regReq.Expiry.IsZero() {
logTrace("Non-zero expiry time requested") logTrace("Non-zero expiry time requested")
newNode.Expiry = &regReq.Expiry newNode.Node.Expiry = &regReq.Expiry
} }
h.registrationCache.Set( h.registrationCache.Set(
machineKey.String(), registrationId,
newNode, newNode,
) )
h.handleNewNode(writer, regReq, machineKey) h.handleNewNode(writer, regReq, registrationId)
return return
} }
@ -206,15 +242,12 @@ func (h *Headscale) handleRegister(
} }
if regReq.Followup != "" { if regReq.Followup != "" {
select { h.waitForFollowup(req, regReq, logTrace)
case <-req.Context().Done():
return return
case <-time.After(registrationHoldoff):
}
} }
// The node has expired or it is logged out // The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey) h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{} node.Expiry = &time.Time{}
@ -444,16 +477,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode( func (h *Headscale) handleNewNode(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, registrationId types.RegistrationID,
) { ) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The node registration is new, redirect the client to the registration URL // The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url") logTrace("The node is new, sending auth url")
resp.AuthURL = h.authProvider.AuthURL(machineKey) resp.AuthURL = h.authProvider.AuthURL(registrationId)
respBody, err := json.Marshal(resp) respBody, err := json.Marshal(resp)
if err != nil { if err != nil {
@ -660,6 +693,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
node types.Node, node types.Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registrationId types.RegistrationID,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
@ -673,12 +707,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace(). log.Trace().
Caller(). Caller().
Str("node", node.Hostname). Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()). Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()). Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register") Msg("Node registration has expired or logged out. Sending a auth url to register")
resp.AuthURL = h.authProvider.AuthURL(machineKey) resp.AuthURL = h.authProvider.AuthURL(registrationId)
respBody, err := json.Marshal(resp) respBody, err := json.Marshal(resp)
if err != nil { if err != nil {
@ -703,7 +737,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machineKey.ShortString()). Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()). Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname). Str("node", node.Hostname).

View file

@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct { type HSDatabase struct {
DB *gorm.DB DB *gorm.DB
cfg *types.DatabaseConfig cfg *types.DatabaseConfig
regCache *zcache.Cache[string, types.Node] regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
baseDomain string baseDomain string
} }
@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase( func NewHeadscaleDatabase(
cfg types.DatabaseConfig, cfg types.DatabaseConfig,
baseDomain string, baseDomain string,
regCache *zcache.Cache[string, types.Node], regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
) (*HSDatabase, error) { ) (*HSDatabase, error) {
dbConn, err := openDB(cfg) dbConn, err := openDB(cfg)
if err != nil { if err != nil {

View file

@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err return dst, err
} }
func emptyCache() *zcache.Cache[string, types.Node] { func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[string, types.Node](time.Minute, time.Hour) return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
} }
// requireConstraintFailed checks if the error is a constraint failure with // requireConstraintFailed checks if the error is a constraint failure with

View file

@ -320,7 +320,7 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
} }
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
mkey key.MachinePublic, registrationID types.RegistrationID,
userID types.UserID, userID types.UserID,
nodeExpiry *time.Time, nodeExpiry *time.Time,
registrationMethod string, registrationMethod string,
@ -328,7 +328,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
ipv6 *netip.Addr, ipv6 *netip.Addr,
) (*types.Node, error) { ) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(mkey.String()); ok { if reg, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID) user, err := GetUserByID(tx, userID)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
@ -338,36 +338,40 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
} }
log.Debug(). log.Debug().
Str("machine_key", mkey.ShortString()). Str("registration_id", registrationID.String()).
Str("username", user.Username()). Str("username", user.Username()).
Str("registrationMethod", registrationMethod). Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback") Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user // Registration of expired node with different user
if node.ID != 0 && if reg.Node.ID != 0 &&
node.UserID != user.ID { reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser return nil, ErrDifferentRegisteredUser
} }
node.UserID = user.ID reg.Node.UserID = user.ID
node.User = *user reg.Node.User = *user
node.RegisterMethod = registrationMethod reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil { if nodeExpiry != nil {
node.Expiry = nodeExpiry reg.Node.Expiry = nodeExpiry
} }
node, err := RegisterNode( node, err := RegisterNode(
tx, tx,
node, reg.Node,
ipv4, ipv6, ipv4, ipv6,
) )
if err == nil { if err == nil {
hsdb.regCache.Delete(mkey.String()) hsdb.regCache.Delete(registrationID)
} }
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
return node, err return node, err
} }

View file

@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) { ) (*v1.RegisterNodeResponse, error) {
log.Trace(). log.Trace().
Str("user", request.GetUser()). Str("user", request.GetUser()).
Str("machine_key", request.GetKey()). Str("registration_id", request.GetKey()).
Msg("Registering node") Msg("Registering node")
var mkey key.MachinePublic registrationId, err := types.RegistrationIDFromString(request.GetKey())
err := mkey.UnmarshalText([]byte(request.GetKey()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -247,7 +246,7 @@ func (api headscaleV1APIServer) RegisterNode(
} }
node, err := api.h.db.RegisterNodeFromAuthCallback( node, err := api.h.db.RegisterNodeFromAuthCallback(
mkey, registrationId,
types.UserID(user.ID), types.UserID(user.ID),
nil, nil,
util.RegisterMethodCLI, util.RegisterMethodCLI,
@ -839,16 +838,15 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: "DebugTestNode", Hostname: "DebugTestNode",
} }
var mkey key.MachinePublic registrationId, err := types.RegistrationIDFromString(request.GetKey())
err = mkey.UnmarshalText([]byte(request.GetKey()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
nodeKey := key.NewNode() nodeKey := key.NewNode()
newNode := types.Node{ newNode := types.RegisterNode{
MachineKey: mkey, Node: types.Node{
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostname: request.GetName(), Hostname: request.GetName(),
User: *user, User: *user,
@ -857,18 +855,20 @@ func (api headscaleV1APIServer) DebugCreateNode(
LastSeen: &time.Time{}, LastSeen: &time.Time{},
Hostinfo: &hostinfo, Hostinfo: &hostinfo,
},
Registered: make(chan struct{}),
} }
log.Debug(). log.Debug().
Str("machine_key", mkey.ShortString()). Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache") Msg("adding debug machine via CLI, appending to registration cache")
api.h.registrationCache.Set( api.h.registrationCache.Set(
mkey.String(), registrationId,
newNode, newNode,
) )
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
} }
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

View file

@ -8,13 +8,13 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/attrs"
"github.com/chasefleming/elem-go/styles" "github.com/chasefleming/elem-go/styles"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -32,8 +32,6 @@ const (
// See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go // See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 39 NoiseCapabilityVersion = 39
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
) )
@ -239,11 +237,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
} }
} }
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string { func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
return fmt.Sprintf( return fmt.Sprintf(
"%s/register/%s", "%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"), strings.TrimSuffix(a.serverURL, "/"),
mKey.String()) registrationId.String())
} }
// RegisterWebAPI shows a simple message in the browser to point to the CLI // RegisterWebAPI shows a simple message in the browser to point to the CLI

View file

@ -21,7 +21,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/types/key"
"zgo.at/zcache/v2" "zgo.at/zcache/v2"
) )
@ -49,7 +48,7 @@ var (
// RegistrationInfo contains both machine key and verifier information for OIDC validation. // RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct { type RegistrationInfo struct {
MachineKey key.MachinePublic RegistrationID types.RegistrationID
Verifier *string Verifier *string
} }
@ -112,11 +111,11 @@ func NewAuthProviderOIDC(
}, nil }, nil
} }
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string { func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
return fmt.Sprintf( return fmt.Sprintf(
"%s/register/%s", "%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"), strings.TrimSuffix(a.serverURL, "/"),
mKey.String()) registrationID.String())
} }
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param // Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /register/:mKey. // Listens in /register/:registration_id.
func (a *AuthProviderOIDC) RegisterHandler( func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
vars := mux.Vars(req) vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"] registrationIdStr, ok := vars["registration_id"]
log.Debug().
Caller().
Str("machine_key", machineKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")
// We need to make sure we dont open for XSS style injections, if the parameter that // We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render // is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error. // the template and log an error.
var machineKey key.MachinePublic registrationId, err := types.RegistrationIDFromString(registrationIdStr)
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)
if err != nil { if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest) http.Error(writer, "invalid registration ID", http.StatusBadRequest)
return return
} }
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Bool("ok", ok).
Msg("Received oidc register call")
// Set the state and nonce cookies to protect against CSRF attacks // Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state") state, err := setCSRFCookie(writer, req, "state")
if err != nil { if err != nil {
@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
// Initialize registration info with machine key // Initialize registration info with machine key
registrationInfo := RegistrationInfo{ registrationInfo := RegistrationInfo{
MachineKey: machineKey, RegistrationID: registrationId,
} }
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@ -326,8 +322,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
} }
// Register the node if it does not exist. // Register the node if it does not exist.
if mKey != nil { if registrationId != nil {
if err := a.registerNode(user, mKey, nodeExpiry); err != nil { if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError) http.Error(writer, err.Error(), http.StatusInternalServerError)
return return
} }
@ -456,22 +452,14 @@ func validateOIDCAllowedUsers(
return nil return nil
} }
// getMachineKeyFromState retrieves the machine key from the state // getRegistrationIDFromState retrieves the registration ID from the state.
// cache. If the machine key is found, it will try retrieve the func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
regInfo, ok := a.registrationCache.Get(state) regInfo, ok := a.registrationCache.Get(state)
if !ok { if !ok {
return nil, nil return nil
} }
// retrieve node information if it exist return &regInfo.RegistrationID
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
return node, &regInfo.MachineKey
} }
// reauthenticateNode updates the node expiry in the database // reauthenticateNode updates the node expiry in the database
@ -558,7 +546,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) registerNode( func (a *AuthProviderOIDC) registerNode(
user *types.User, user *types.User,
machineKey *key.MachinePublic, registrationID types.RegistrationID,
expiry time.Time, expiry time.Time,
) error { ) error {
ipv4, ipv6, err := a.ipAlloc.Next() ipv4, ipv6, err := a.ipAlloc.Next()
@ -567,7 +555,7 @@ func (a *AuthProviderOIDC) registerNode(
} }
if _, err := a.db.RegisterNodeFromAuthCallback( if _, err := a.db.RegisterNodeFromAuthCallback(
*machineKey, registrationID,
types.UserID(user.ID), types.UserID(user.ID),
&expiry, &expiry,
util.RegisterMethodOIDC, util.RegisterMethodOIDC,