From b177b24c6d6acef013eff6c5ea12a8e389224df6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 17 Jan 2025 12:46:40 +0100 Subject: [PATCH] move refresh logic to db layer Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 5 ++ hscontrol/db/node.go | 109 +++++++++++++++++++++++----------------- hscontrol/grpcv1.go | 2 +- hscontrol/oidc.go | 115 ++++++++++++++++--------------------------- 4 files changed, 113 insertions(+), 118 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 491594c3..23d66bb5 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -124,6 +124,9 @@ func (h *Headscale) handleRegister( logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") + + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) logTrace("handleRegister database lookup has returned") if errors.Is(err, gorm.ErrRecordNotFound) { @@ -329,6 +332,8 @@ func (h *Headscale) handleAuthKey( // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if node != nil { log.Trace(). diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index d7b0864f..f722d9ab 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -343,64 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } -func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( +// HandleNodeFromAuthPath is called from the OIDC or CLI auth path +// with a registrationID to register or reauthenticate a node. +// If the node found in the registration cache is not already registered, +// it will be registered with the user and the node will be removed from the cache. +// If the node is already registered, the expiry will be updated. +// The node, and a boolean indicating if it was a new node or not, will be returned. +func (hsdb *HSDatabase) HandleNodeFromAuthPath( registrationID types.RegistrationID, userID types.UserID, nodeExpiry *time.Time, registrationMethod string, ipv4 *netip.Addr, ipv6 *netip.Addr, -) (*types.Node, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { +) (*types.Node, bool, error) { + var newNode bool + node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { if reg, ok := hsdb.regCache.Get(registrationID); ok { - user, err := GetUserByID(tx, userID) - if err != nil { - return nil, fmt.Errorf( - "failed to find user in register node from auth callback, %w", - err, + if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { + user, err := GetUserByID(tx, userID) + if err != nil { + return nil, fmt.Errorf( + "failed to find user in register node from auth callback, %w", + err, + ) + } + + log.Debug(). + Str("registration_id", registrationID.String()). + Str("username", user.Username()). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). + Msg("Registering node from API/CLI or auth callback") + + // TODO(kradalby): This looks quite wrong? why ID 0? + // Why not always? + // Registration of expired node with different user + if reg.Node.ID != 0 && + reg.Node.UserID != user.ID { + return nil, ErrDifferentRegisteredUser + } + + reg.Node.UserID = user.ID + reg.Node.User = *user + reg.Node.RegisterMethod = registrationMethod + + if nodeExpiry != nil { + reg.Node.Expiry = nodeExpiry + } + + node, err := RegisterNode( + tx, + reg.Node, + ipv4, ipv6, ) + + if err == nil { + hsdb.regCache.Delete(registrationID) + } + + // Signal to waiting clients that the machine has been registered. + close(reg.Registered) + newNode = true + return node, err + } else { + // If the node is already registered, this is a refresh. + err := NodeSetExpiry(tx, node.ID, *nodeExpiry) + if err != nil { + return nil, err + } + return node, nil } - - log.Debug(). - Str("registration_id", registrationID.String()). - Str("username", user.Username()). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). - Msg("Registering node from API/CLI or auth callback") - - // TODO(kradalby): This looks quite wrong? why ID 0? - // Why not always? - // Registration of expired node with different user - if reg.Node.ID != 0 && - reg.Node.UserID != user.ID { - return nil, ErrDifferentRegisteredUser - } - - reg.Node.UserID = user.ID - reg.Node.User = *user - reg.Node.RegisterMethod = registrationMethod - - if nodeExpiry != nil { - reg.Node.Expiry = nodeExpiry - } - - node, err := RegisterNode( - tx, - reg.Node, - ipv4, ipv6, - ) - - if err == nil { - hsdb.regCache.Delete(registrationID) - } - - // Signal to waiting clients that the machine has been registered. - close(reg.Registered) - return node, err } return nil, ErrNodeNotFoundRegistrationCache }) + + return node, newNode, err } func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 1d4ecbb1..e438332a 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -245,7 +245,7 @@ func (api headscaleV1APIServer) RegisterNode( return nil, fmt.Errorf("looking up user: %w", err) } - node, err := api.h.db.RegisterNodeFromAuthCallback( + node, _, err := api.h.db.HandleNodeFromAuthPath( registrationId, types.UserID(user.ID), nil, diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 73a64b91..5bc548d0 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -286,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Retrieve the node and the machine key from the state cache and - // database. + // TODO(kradalby): Is this comment right? // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - node, mKey := a.getMachineKeyFromState(state) - - // Reauthenticate the node if it does exists. - if node != nil { - err := a.reauthenticateNode(node, nodeExpiry) - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - - // TODO(kradalby): replace with go-elem - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: user.DisplayNameOrUsername(), - Verb: "Reauthenticated", - }); err != nil { - http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError) - return - } - - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(content.Bytes()) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return - } + registrationId := a.getRegistrationIDFromState(state) // Register the node if it does not exist. if registrationId != nil { - if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil { + verb := "Reauthenticated" + newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry) + if err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return } - content, err := renderOIDCCallbackTemplate(user) + if newNode { + verb = "Authenticated" + } + + // TODO(kradalby): replace with go-elem + content, err := renderOIDCCallbackTemplate(user, verb) if err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return @@ -462,33 +440,6 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis return ®Info.RegistrationID } -// reauthenticateNode updates the node expiry in the database -// and notifies the node and its peers about the change. -func (a *AuthProviderOIDC) reauthenticateNode( - node *types.Node, - expiry time.Time, -) error { - err := a.db.NodeSetExpiry(node.ID, expiry) - if err != nil { - return err - } - - ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) - a.notifier.NotifyByNodeID( - ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{node.ID}, - }, - node.ID, - ) - - ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) - a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) - - return nil -} - func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( claims *types.OIDCClaims, ) (*types.User, error) { @@ -544,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return user, nil } -func (a *AuthProviderOIDC) registerNode( +func (a *AuthProviderOIDC) handleRegistrationID( user *types.User, registrationID types.RegistrationID, expiry time.Time, -) error { +) (bool, error) { ipv4, ipv6, err := a.ipAlloc.Next() if err != nil { - return err + return false, err } - if _, err := a.db.RegisterNodeFromAuthCallback( + node, newNode, err := a.db.HandleNodeFromAuthPath( registrationID, types.UserID(user.ID), &expiry, util.RegisterMethodOIDC, ipv4, ipv6, - ); err != nil { - return fmt.Errorf("could not register node: %w", err) - } - - err = nodesChangedHook(a.db, a.polMan, a.notifier) + ) if err != nil { - return fmt.Errorf("updating resources using node: %w", err) + return false, fmt.Errorf("could not register node: %w", err) } - return nil + // Send an update to all nodes if this is a new node that they need to know + // about. + // If this is a refresh, just send new expiry updates. + if newNode { + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return false, fmt.Errorf("updating resources using node: %w", err) + } + } else { + ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) + a.notifier.NotifyByNodeID( + ctx, + types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: []types.NodeID{node.ID}, + }, + node.ID, + ) + + ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) + a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) + } + + return newNode, nil } // TODO(kradalby): // Rewrite in elem-go. func renderOIDCCallbackTemplate( user *types.User, + verb string, ) (*bytes.Buffer, error) { var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ User: user.DisplayNameOrUsername(), - Verb: "Authenticated", + Verb: verb, }); err != nil { return nil, fmt.Errorf("rendering OIDC callback template: %w", err) }