headscale/hscontrol/auth.go
2025-02-01 09:16:51 +00:00

266 lines
6.7 KiB
Go

package hscontrol
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
// AuthProvider is the set of hooks the registration flow requires from an
// authentication backend.
type AuthProvider interface {
	// RegisterHandler has the signature of an HTTP handler function.
	RegisterHandler(w http.ResponseWriter, r *http.Request)

	// AuthURL returns the URL for the given registration ID.
	// handleRegisterInteractive sends it to the client as the AuthURL of
	// the register response.
	AuthURL(id types.RegistrationID) string
}
// handleRegister is the entry point for a registration request. It routes
// the request to exactly one of three handlers:
//
//   - handleExistingNode, when a node with the request's node key is
//     already stored in the database;
//   - handleRegisterWithAuthKey, when the request carries a non-empty
//     auth key;
//   - handleRegisterInteractive, in every other case.
//
// If the request has a Followup URL and no existing node was found, the
// call first blocks in waitForFollowup (bounded by ctx) and then continues
// with the auth-key / interactive decision.
//
// Errors are wrapped with a short description of the step that failed.
func (h *Headscale) handleRegister(
	ctx context.Context,
	regReq tailcfg.RegisterRequest,
	machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
	existing, lookupErr := h.db.GetNodeByNodeKey(regReq.NodeKey)
	// A "record not found" result is not a failure here; it just means
	// the node key is unknown. Any other lookup error aborts the request.
	if lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
		return nil, fmt.Errorf("looking up node in database: %w", lookupErr)
	}

	if existing != nil {
		resp, err := h.handleExistingNode(existing, regReq, machineKey)
		if err != nil {
			return nil, fmt.Errorf("handling existing node: %w", err)
		}

		return resp, nil
	}

	if regReq.Followup != "" {
		// TODO(kradalby): Does this need to return an error of some sort?
		// Maybe if the registration fails down the line it can be sent
		// on the channel and returned here?
		h.waitForFollowup(ctx, regReq)
	}

	// Without a usable auth key the only remaining option is the
	// interactive flow.
	if auth := regReq.Auth; auth == nil || auth.AuthKey == "" {
		resp, err := h.handleRegisterInteractive(regReq, machineKey)
		if err != nil {
			return nil, fmt.Errorf("handling register interactive: %w", err)
		}

		return resp, nil
	}

	resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
	if err != nil {
		return nil, fmt.Errorf("handling register with auth key: %w", err)
	}

	return resp, nil
}
// handleExistingNode answers a registration request whose node key already
// belongs to a stored node.
//
// The request is rejected when machineKey differs from the machine key
// recorded on the node. If the node is not expired and the request carries
// a non-zero expiry, that expiry is interpreted as follows:
//
//   - in the future: rejected, clients may not extend their key;
//   - in the past: treated as a logout. An ephemeral node is deleted and
//     removal/change updates are sent to all nodes; the response reports
//     the node key as expired.
//
// Unless the request was rejected, the requested expiry is then written to
// the database and an expiry update is sent through NotifyWithIgnore, with
// the node's own ID passed as the ignore argument.
//
// The response always reports the machine as authorized and carries the
// user and login of the node's owner.
func (h *Headscale) handleExistingNode(
	node *types.Node,
	regReq tailcfg.RegisterRequest,
	machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
	if node.MachineKey != machineKey {
		return nil, errors.New("node already exists with different machine key")
	}

	keyExpired := node.IsExpired()

	if newExpiry := regReq.Expiry; !keyExpired && !newExpiry.IsZero() {
		// The client is trying to extend their key, this is not allowed.
		if newExpiry.After(time.Now()) {
			return nil, errors.New("extending key is not allowed")
		}

		// If the request expiry is in the past, we consider it a logout.
		if newExpiry.Before(time.Now()) {
			if node.IsEphemeral() {
				changed, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap())
				if err != nil {
					return nil, fmt.Errorf("deleting ephemeral node: %w", err)
				}

				notifyCtx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")

				removal := types.StateUpdate{
					Type:    types.StatePeerRemoved,
					Removed: []types.NodeID{node.ID},
				}
				h.nodeNotifier.NotifyAll(notifyCtx, removal)

				if changed != nil {
					h.nodeNotifier.NotifyAll(notifyCtx, types.StateUpdate{
						Type:        types.StatePeerChanged,
						ChangeNodes: changed,
					})
				}
			}

			keyExpired = true
		}

		// NOTE(review): for an ephemeral node this still runs after the
		// node has been deleted above — confirm that is intended.
		if err := h.db.NodeSetExpiry(node.ID, newExpiry); err != nil {
			return nil, fmt.Errorf("setting node expiry: %w", err)
		}

		notifyCtx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
		expiryUpdate := types.StateUpdateExpire(node.ID, newExpiry)
		h.nodeNotifier.NotifyWithIgnore(notifyCtx, expiryUpdate, node.ID)
	}

	resp := &tailcfg.RegisterResponse{
		// TODO(kradalby): Only send for user-owned nodes
		// and not tagged nodes when tags is working.
		User:           *node.User.TailscaleUser(),
		Login:          *node.User.TailscaleLogin(),
		NodeKeyExpired: keyExpired,

		// Headscale does not implement the concept of machine authorization
		// so we always return true here.
		// Revisit this if #2176 gets implemented.
		MachineAuthorized: true,
	}

	return resp, nil
}
// waitForFollowup blocks until the pending registration referenced by
// regReq.Followup signals on its Registered channel, or until ctx is done,
// whichever happens first.
//
// It returns immediately, and without reporting anything to the caller,
// when the follow-up URL cannot be parsed, when its path does not yield a
// valid registration ID, or when the registration cache holds no entry for
// that ID.
func (h *Headscale) waitForFollowup(
	ctx context.Context,
	regReq tailcfg.RegisterRequest,
) {
	followupURL, err := url.Parse(regReq.Followup)
	if err != nil {
		return
	}

	// What remains of the path after removing "/register/" is handed to
	// the registration ID parser.
	rawID := strings.ReplaceAll(followupURL.Path, "/register/", "")

	regID, err := types.RegistrationIDFromString(rawID)
	if err != nil {
		return
	}

	pending, found := h.registrationCache.Get(regID)
	if !found {
		return
	}

	// Either outcome ends the wait; the caller decides what happens next.
	select {
	case <-ctx.Done():
	case <-pending.Registered:
	}
}
// handleRegisterWithAuthKey registers a new node using the pre-auth key
// carried in regReq.Auth.AuthKey.
//
// The key is validated first. The node is then built from the request and
// the key's user, IP addresses are allocated, and the node is written to
// the database. A key that is not reusable is marked as used inside the
// same db.Write call. Afterwards nodesChangedHook runs, and if it reports
// that no update was sent, a peer-added update is sent to all nodes.
//
// The caller must ensure regReq.Auth is non-nil. An error is returned when
// the key is invalid, when the request carries no Hostinfo, when IP
// allocation fails, or when any database step fails.
func (h *Headscale) handleRegisterWithAuthKey(
	regReq tailcfg.RegisterRequest,
	machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
	pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
	if err != nil {
		return nil, fmt.Errorf("invalid pre auth key: %w", err)
	}

	// Hostinfo is a pointer supplied by the client; reject the request
	// instead of panicking on a nil dereference below.
	if regReq.Hostinfo == nil {
		return nil, errors.New("register request has no hostinfo")
	}

	nodeToRegister := types.Node{
		Hostname:       regReq.Hostinfo.Hostname,
		UserID:         pak.User.ID,
		User:           pak.User,
		MachineKey:     machineKey,
		NodeKey:        regReq.NodeKey,
		Hostinfo:       regReq.Hostinfo,
		LastSeen:       ptr.To(time.Now()),
		RegisterMethod: util.RegisterMethodAuthKey,

		// TODO(kradalby): This should not be set on the node,
		// they should be looked up through the key, which is
		// attached to the node.
		ForcedTags: pak.Proto().GetAclTags(),
		AuthKey:    pak,
		AuthKeyID:  &pak.ID,
	}

	// A zero expiry in the request leaves the node's expiry unset.
	if !regReq.Expiry.IsZero() {
		nodeToRegister.Expiry = &regReq.Expiry
	}

	// NOTE(review): the addresses are allocated before the database write
	// and are not handed back here if the write fails — confirm the
	// allocator tolerates that.
	ipv4, ipv6, err := h.ipAlloc.Next()
	if err != nil {
		return nil, fmt.Errorf("allocating IPs: %w", err)
	}

	node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
		registered, regErr := db.RegisterNode(tx,
			nodeToRegister,
			ipv4, ipv6,
		)
		if regErr != nil {
			return nil, fmt.Errorf("registering node: %w", regErr)
		}

		// Single-use keys are consumed together with the registration.
		if !pak.Reusable {
			if useErr := db.UsePreAuthKey(tx, pak); useErr != nil {
				return nil, fmt.Errorf("using pre auth key: %w", useErr)
			}
		}

		return registered, nil
	})
	if err != nil {
		return nil, err
	}

	updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
	if err != nil {
		return nil, fmt.Errorf("nodes changed hook: %w", err)
	}

	// Only announce the new peer ourselves when the hook did not already
	// send an update.
	if !updateSent {
		ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
		h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID))
	}

	return &tailcfg.RegisterResponse{
		MachineAuthorized: true,
		NodeKeyExpired:    node.IsExpired(),
		User:              *pak.User.TailscaleUser(),
		Login:             *pak.User.TailscaleLogin(),
	}, nil
}
// handleRegisterInteractive starts an interactive registration.
//
// It generates a fresh registration ID, stores a pending node built from
// the request in the registration cache under that ID, and returns a
// response whose AuthURL comes from the configured auth provider. The
// cached entry carries a Registered channel, which waitForFollowup selects
// on when the client follows up.
//
// An error is returned when the request carries no Hostinfo or when the
// registration ID cannot be generated.
func (h *Headscale) handleRegisterInteractive(
	regReq tailcfg.RegisterRequest,
	machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
	// Hostinfo is a pointer supplied by the client; reject the request
	// instead of panicking on a nil dereference below.
	if regReq.Hostinfo == nil {
		return nil, errors.New("register request has no hostinfo")
	}

	registrationID, err := types.NewRegistrationID()
	if err != nil {
		return nil, fmt.Errorf("generating registration ID: %w", err)
	}

	newNode := types.RegisterNode{
		Node: types.Node{
			Hostname:   regReq.Hostinfo.Hostname,
			MachineKey: machineKey,
			NodeKey:    regReq.NodeKey,
			Hostinfo:   regReq.Hostinfo,
			LastSeen:   ptr.To(time.Now()),
		},
		Registered: make(chan struct{}),
	}

	// A zero expiry in the request leaves the node's expiry unset.
	if !regReq.Expiry.IsZero() {
		newNode.Node.Expiry = &regReq.Expiry
	}

	h.registrationCache.Set(
		registrationID,
		newNode,
	)

	return &tailcfg.RegisterResponse{
		AuthURL: h.authProvider.AuthURL(registrationID),
	}, nil
}