move refresh logic to db layer

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-01-17 12:46:40 +01:00
parent d13c15cb67
commit b177b24c6d
No known key found for this signature in database
4 changed files with 113 additions and 118 deletions

View file

@ -124,6 +124,9 @@ func (h *Headscale) handleRegister(
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
@ -329,6 +332,8 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if node != nil {
log.Trace().

View file

@ -343,16 +343,24 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
// it will be registered with the user and the node will be removed from the cache.
// If the node is already registered, the expiry will be updated.
// The node, and a boolean indicating if it was a new node or not, will be returned.
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
) (*types.Node, bool, error) {
var newNode bool
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
@ -396,11 +404,22 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
newNode = true
return node, err
} else {
// If the node is already registered, this is a refresh.
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
if err != nil {
return nil, err
}
return node, nil
}
}
return nil, ErrNodeNotFoundRegistrationCache
})
return node, newNode, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {

View file

@ -245,7 +245,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err)
}
node, err := api.h.db.RegisterNodeFromAuthCallback(
node, _, err := api.h.db.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,

View file

@ -286,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// Retrieve the node and the machine key from the state cache and
// database.
// TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
node, mKey := a.getMachineKeyFromState(state)
// Reauthenticate the node if it does exists.
if node != nil {
err := a.reauthenticateNode(node, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
// TODO(kradalby): replace with go-elem
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Reauthenticated",
}); err != nil {
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
registrationId := a.getRegistrationIDFromState(state)
// Register the node if it does not exist.
if registrationId != nil {
if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
content, err := renderOIDCCallbackTemplate(user)
if newNode {
verb = "Authenticated"
}
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
@ -462,33 +440,6 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
return &regInfo.RegistrationID
}
// reauthenticateNode updates the node expiry in the database
// and notifies the node and its peers about the change.
func (a *AuthProviderOIDC) reauthenticateNode(
node *types.Node,
expiry time.Time,
) error {
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
return err
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
@ -544,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil
}
func (a *AuthProviderOIDC) registerNode(
func (a *AuthProviderOIDC) handleRegistrationID(
user *types.User,
registrationID types.RegistrationID,
expiry time.Time,
) error {
) (bool, error) {
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
return false, err
}
if _, err := a.db.RegisterNodeFromAuthCallback(
node, newNode, err := a.db.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
); err != nil {
return fmt.Errorf("could not register node: %w", err)
)
if err != nil {
return false, fmt.Errorf("could not register node: %w", err)
}
// Send an update to all nodes if this is a new node that they need to know
// about.
// If this is a refresh, just send new expiry updates.
if newNode {
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
return false, fmt.Errorf("updating resources using node: %w", err)
}
} else {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
}
return nil
return newNode, nil
}
// TODO(kradalby):
// Rewrite in elem-go.
func renderOIDCCallbackTemplate(
user *types.User,
verb string,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Authenticated",
Verb: verb,
}); err != nil {
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}