move refresh logic to db layer

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-01-17 12:46:40 +01:00
parent d13c15cb67
commit b177b24c6d
No known key found for this signature in database
4 changed files with 113 additions and 118 deletions

View file

@ -124,6 +124,9 @@ func (h *Headscale) handleRegister(
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC() now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB") logTrace("handleRegister called, looking up machine in DB")
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
logTrace("handleRegister database lookup has returned") logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@ -329,6 +332,8 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new node and we will move // exist, then this is a new node and we will move
// on to registration. // on to registration.
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if node != nil { if node != nil {
log.Trace(). log.Trace().

View file

@ -343,64 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
} }
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( // HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
// it will be registered with the user and the node will be removed from the cache.
// If the node is already registered, the expiry will be updated.
// The node, and a boolean indicating if it was a new node or not, will be returned.
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationID types.RegistrationID, registrationID types.RegistrationID,
userID types.UserID, userID types.UserID,
nodeExpiry *time.Time, nodeExpiry *time.Time,
registrationMethod string, registrationMethod string,
ipv4 *netip.Addr, ipv4 *netip.Addr,
ipv6 *netip.Addr, ipv6 *netip.Addr,
) (*types.Node, error) { ) (*types.Node, bool, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { var newNode bool
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok { if reg, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID) if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
if err != nil { user, err := GetUserByID(tx, userID)
return nil, fmt.Errorf( if err != nil {
"failed to find user in register node from auth callback, %w", return nil, fmt.Errorf(
err, "failed to find user in register node from auth callback, %w",
err,
)
}
log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
) )
if err == nil {
hsdb.regCache.Delete(registrationID)
}
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
newNode = true
return node, err
} else {
// If the node is already registered, this is a refresh.
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
if err != nil {
return nil, err
}
return node, nil
} }
log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
)
if err == nil {
hsdb.regCache.Delete(registrationID)
}
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
return node, err
} }
return nil, ErrNodeNotFoundRegistrationCache return nil, ErrNodeNotFoundRegistrationCache
}) })
return node, newNode, err
} }
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {

View file

@ -245,7 +245,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err) return nil, fmt.Errorf("looking up user: %w", err)
} }
node, err := api.h.db.RegisterNodeFromAuthCallback( node, _, err := api.h.db.HandleNodeFromAuthPath(
registrationId, registrationId,
types.UserID(user.ID), types.UserID(user.ID),
nil, nil,

View file

@ -286,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return return
} }
// Retrieve the node and the machine key from the state cache and // TODO(kradalby): Is this comment right?
// database.
// If the node exists, then the node should be reauthenticated, // If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then // if the node does not exist, and the machine key exists, then
// this is a new node that should be registered. // this is a new node that should be registered.
node, mKey := a.getMachineKeyFromState(state) registrationId := a.getRegistrationIDFromState(state)
// Reauthenticate the node if it does exists.
if node != nil {
err := a.reauthenticateNode(node, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
// TODO(kradalby): replace with go-elem
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Reauthenticated",
}); err != nil {
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
// Register the node if it does not exist. // Register the node if it does not exist.
if registrationId != nil { if registrationId != nil {
if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil { verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError) http.Error(writer, err.Error(), http.StatusInternalServerError)
return return
} }
content, err := renderOIDCCallbackTemplate(user) if newNode {
verb = "Authenticated"
}
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil { if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError) http.Error(writer, err.Error(), http.StatusInternalServerError)
return return
@ -462,33 +440,6 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
return &regInfo.RegistrationID return &regInfo.RegistrationID
} }
// reauthenticateNode updates the node expiry in the database
// and notifies the node and its peers about the change.
func (a *AuthProviderOIDC) reauthenticateNode(
node *types.Node,
expiry time.Time,
) error {
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
return err
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims, claims *types.OIDCClaims,
) (*types.User, error) { ) (*types.User, error) {
@ -544,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil return user, nil
} }
func (a *AuthProviderOIDC) registerNode( func (a *AuthProviderOIDC) handleRegistrationID(
user *types.User, user *types.User,
registrationID types.RegistrationID, registrationID types.RegistrationID,
expiry time.Time, expiry time.Time,
) error { ) (bool, error) {
ipv4, ipv6, err := a.ipAlloc.Next() ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil { if err != nil {
return err return false, err
} }
if _, err := a.db.RegisterNodeFromAuthCallback( node, newNode, err := a.db.HandleNodeFromAuthPath(
registrationID, registrationID,
types.UserID(user.ID), types.UserID(user.ID),
&expiry, &expiry,
util.RegisterMethodOIDC, util.RegisterMethodOIDC,
ipv4, ipv6, ipv4, ipv6,
); err != nil { )
return fmt.Errorf("could not register node: %w", err)
}
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil { if err != nil {
return fmt.Errorf("updating resources using node: %w", err) return false, fmt.Errorf("could not register node: %w", err)
} }
return nil // Send an update to all nodes if this is a new node that they need to know
// about.
// If this is a refresh, just send new expiry updates.
if newNode {
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return false, fmt.Errorf("updating resources using node: %w", err)
}
} else {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
}
return newNode, nil
} }
// TODO(kradalby): // TODO(kradalby):
// Rewrite in elem-go. // Rewrite in elem-go.
func renderOIDCCallbackTemplate( func renderOIDCCallbackTemplate(
user *types.User, user *types.User,
verb string,
) (*bytes.Buffer, error) { ) (*bytes.Buffer, error) {
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(), User: user.DisplayNameOrUsername(),
Verb: "Authenticated", Verb: verb,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("rendering OIDC callback template: %w", err) return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
} }