headscale/hscontrol/grpcv1.go
Kristoffer Dalby 68a8ecee7a
Prepare notify channel before sending first update (#1730)
* create channel before sending first update

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* do not notify on register, wait for connect

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-02-12 09:11:17 +01:00

677 lines
16 KiB
Go

// nolint
package hscontrol
import (
	"context"
	"fmt"
	"strings"
	"time"
	"unicode"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/rs/zerolog/log"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"gorm.io/gorm"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
)
// headscaleV1APIServer serves the gRPC v1 Headscale API. newHeadscaleV1APIServer
// returns it as a v1.HeadscaleServiceServer. The embedded
// v1.UnimplementedHeadscaleServiceServer presumably supplies the fallback
// implementations for RPCs not defined on this type (grpc-go convention; confirm
// against the generated code).
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
	v1.UnimplementedHeadscaleServiceServer
	// h is the owning Headscale instance; handlers reach the database
	// (h.db), the node notifier (h.nodeNotifier), config and ACL policy through it.
	h *Headscale
}
// newHeadscaleV1APIServer wraps h in a headscaleV1APIServer value and exposes
// it through the generated v1.HeadscaleServiceServer interface.
func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer {
	srv := headscaleV1APIServer{h: h}

	return srv
}
// GetUser looks up the user named in the request and returns its protobuf
// form. Database errors are returned to the caller unchanged.
func (api headscaleV1APIServer) GetUser(
	ctx context.Context,
	request *v1.GetUserRequest,
) (*v1.GetUserResponse, error) {
	found, lookupErr := api.h.db.GetUser(request.GetName())
	if lookupErr != nil {
		return nil, lookupErr
	}

	resp := &v1.GetUserResponse{User: found.Proto()}

	return resp, nil
}
// CreateUser creates a user with the requested name and returns the stored
// record in protobuf form. Database errors are returned unchanged.
func (api headscaleV1APIServer) CreateUser(
	ctx context.Context,
	request *v1.CreateUserRequest,
) (*v1.CreateUserResponse, error) {
	created, createErr := api.h.db.CreateUser(request.GetName())
	if createErr != nil {
		return nil, createErr
	}

	resp := &v1.CreateUserResponse{User: created.Proto()}

	return resp, nil
}
// RenameUser renames a user and then re-reads it under its new name, so the
// response reflects what was actually persisted.
func (api headscaleV1APIServer) RenameUser(
	ctx context.Context,
	request *v1.RenameUserRequest,
) (*v1.RenameUserResponse, error) {
	oldName, newName := request.GetOldName(), request.GetNewName()

	if err := api.h.db.RenameUser(oldName, newName); err != nil {
		return nil, err
	}

	renamed, err := api.h.db.GetUser(newName)
	if err != nil {
		return nil, err
	}

	return &v1.RenameUserResponse{User: renamed.Proto()}, nil
}
// DeleteUser removes the named user via the database layer's DestroyUser and
// returns an empty response on success.
func (api headscaleV1APIServer) DeleteUser(
	ctx context.Context,
	request *v1.DeleteUserRequest,
) (*v1.DeleteUserResponse, error) {
	if err := api.h.db.DestroyUser(request.GetName()); err != nil {
		return nil, err
	}

	return &v1.DeleteUserResponse{}, nil
}
// ListUsers returns every user in protobuf form, in the order the database
// layer yields them. The converted list is also trace-logged.
func (api headscaleV1APIServer) ListUsers(
	ctx context.Context,
	request *v1.ListUsersRequest,
) (*v1.ListUsersResponse, error) {
	users, err := api.h.db.ListUsers()
	if err != nil {
		return nil, err
	}

	protoUsers := make([]*v1.User, 0, len(users))
	for _, u := range users {
		protoUsers = append(protoUsers, u.Proto())
	}

	log.Trace().Caller().Interface("users", protoUsers).Msg("")

	return &v1.ListUsersResponse{Users: protoUsers}, nil
}
// CreatePreAuthKey issues a pre-auth key for the requested user.
//
// A missing expiration is passed on as the zero time.Time. Every requested
// ACL tag is validated first. The first invalid tag aborts the call with
// codes.InvalidArgument and a response whose PreAuthKey is nil.
func (api headscaleV1APIServer) CreatePreAuthKey(
	ctx context.Context,
	request *v1.CreatePreAuthKeyRequest,
) (*v1.CreatePreAuthKeyResponse, error) {
	var expiration time.Time
	if ts := request.GetExpiration(); ts != nil {
		expiration = ts.AsTime()
	}

	tags := request.AclTags
	for _, tag := range tags {
		if tagErr := validateTag(tag); tagErr != nil {
			rejected := &v1.CreatePreAuthKeyResponse{PreAuthKey: nil}

			return rejected, status.Error(codes.InvalidArgument, tagErr.Error())
		}
	}

	created, err := api.h.db.CreatePreAuthKey(
		request.GetUser(),
		request.GetReusable(),
		request.GetEphemeral(),
		&expiration,
		tags,
	)
	if err != nil {
		return nil, err
	}

	return &v1.CreatePreAuthKeyResponse{PreAuthKey: created.Proto()}, nil
}
// ExpirePreAuthKey looks up the (user, key) pair and expires it. Both steps run
// inside one gorm transaction, so a failure in either step aborts the whole operation.
func (api headscaleV1APIServer) ExpirePreAuthKey(
	ctx context.Context,
	request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
	txErr := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
		pak, lookupErr := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
		if lookupErr != nil {
			return lookupErr
		}

		return db.ExpirePreAuthKey(tx, pak)
	})
	if txErr != nil {
		return nil, txErr
	}

	return &v1.ExpirePreAuthKeyResponse{}, nil
}
// ListPreAuthKeys returns all pre-auth keys belonging to the requested user,
// converted to protobuf.
func (api headscaleV1APIServer) ListPreAuthKeys(
	ctx context.Context,
	request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
	keys, err := api.h.db.ListPreAuthKeys(request.GetUser())
	if err != nil {
		return nil, err
	}

	out := make([]*v1.PreAuthKey, 0, len(keys))
	for _, pak := range keys {
		out = append(out, pak.Proto())
	}

	return &v1.ListPreAuthKeysResponse{PreAuthKeys: out}, nil
}
// RegisterNode completes a CLI-driven registration.
//
// The request's machine key (text form) is parsed, then the registration
// cache entry for that key is turned into a node owned by the requested user.
// This happens in a write transaction via db.RegisterNodeFromAuthCallback.
// A malformed machine key is returned as the parse error.
func (api headscaleV1APIServer) RegisterNode(
	ctx context.Context,
	request *v1.RegisterNodeRequest,
) (*v1.RegisterNodeResponse, error) {
	userName := request.GetUser()
	rawKey := request.GetKey()

	log.Trace().
		Str("user", userName).
		Str("machine_key", rawKey).
		Msg("Registering node")

	var machineKey key.MachinePublic
	if parseErr := machineKey.UnmarshalText([]byte(rawKey)); parseErr != nil {
		return nil, parseErr
	}

	registered, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
		return db.RegisterNodeFromAuthCallback(
			tx,
			api.h.registrationCache,
			machineKey,
			userName,
			nil,
			util.RegisterMethodCLI,
			api.h.cfg.IPPrefixes,
		)
	})
	if err != nil {
		return nil, err
	}

	return &v1.RegisterNodeResponse{Node: registered.Proto()}, nil
}
// GetNode fetches a node by ID and returns it in protobuf form. The Online
// flag is filled in from the notifier's view of connected machine keys.
func (api headscaleV1APIServer) GetNode(
	ctx context.Context,
	request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) {
	found, err := api.h.db.GetNodeByID(request.GetNodeId())
	if err != nil {
		return nil, err
	}

	pbNode := found.Proto()
	pbNode.Online = api.h.nodeNotifier.IsConnected(found.MachineKey)

	return &v1.GetNodeResponse{Node: pbNode}, nil
}
// SetTags replaces the tags of the node identified by request.NodeId.
//
// Every tag is validated with validateTag before anything is written. A
// malformed tag is a client error and is reported as codes.InvalidArgument.
// A plain error would reach gRPC clients as codes.Unknown, and this matches
// how CreatePreAuthKey reports invalid tags. On success, all other peers are
// sent a StatePeerChanged update for the node, and the refreshed node is
// returned.
func (api headscaleV1APIServer) SetTags(
	ctx context.Context,
	request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
	for _, tag := range request.GetTags() {
		if err := validateTag(tag); err != nil {
			return nil, status.Error(codes.InvalidArgument, err.Error())
		}
	}

	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
		err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
		if err != nil {
			return nil, err
		}

		return db.GetNodeByID(tx, request.GetNodeId())
	})
	if err != nil {
		return &v1.SetTagsResponse{
			Node: nil,
		}, status.Error(codes.InvalidArgument, err.Error())
	}

	stateUpdate := types.StateUpdate{
		Type:        types.StatePeerChanged,
		ChangeNodes: types.Nodes{node},
		Message:     "called from api.SetTags",
	}
	if stateUpdate.Valid() {
		// The node itself is skipped; only its peers need the change.
		ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
	}

	log.Trace().
		Str("node", node.Hostname).
		Strs("tags", request.GetTags()).
		Msg("Changing tags of node")

	return &v1.SetTagsResponse{Node: node.Proto()}, nil
}
// validateTag checks that tag is an acceptable ACL tag. It must:
//   - start with "tag:",
//   - be entirely lowercase, and
//   - contain no whitespace anywhere, including leading or trailing whitespace.
//
// It returns nil for a valid tag, or an error describing the first rule violated.
func validateTag(tag string) error {
	if !strings.HasPrefix(tag, "tag:") {
		return fmt.Errorf("tag must start with the string 'tag:'")
	}

	if strings.ToLower(tag) != tag {
		return fmt.Errorf("tag should be lowercase")
	}

	// Counting strings.Fields misses trailing whitespace ("tag:foo " yields a
	// single field), so scan for any Unicode space character directly.
	if strings.IndexFunc(tag, unicode.IsSpace) >= 0 {
		return fmt.Errorf("tag should not contains space")
	}

	return nil
}
// DeleteNode removes the node with the requested ID. The notifier's current
// connection map is passed to the database layer. On success, every connected
// node is sent a StatePeerRemoved update for the deleted node's ID.
func (api headscaleV1APIServer) DeleteNode(
	ctx context.Context,
	request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) {
	target, err := api.h.db.GetNodeByID(request.GetNodeId())
	if err != nil {
		return nil, err
	}

	if err = api.h.db.DeleteNode(target, api.h.nodeNotifier.ConnectedMap()); err != nil {
		return nil, err
	}

	removal := types.StateUpdate{
		Type:    types.StatePeerRemoved,
		Removed: []tailcfg.NodeID{tailcfg.NodeID(target.ID)},
	}
	if removal.Valid() {
		notifyCtx := types.NotifyCtx(ctx, "cli-deletenode", target.Hostname)
		api.h.nodeNotifier.NotifyAll(notifyCtx, removal)
	}

	return &v1.DeleteNodeResponse{}, nil
}
// ExpireNode sets the expiry of the requested node to the current time.
//
// The expiry write and the re-read of the node run in one db.Write call. If
// the expiry update fails, that error is returned and no notifications are
// sent. The original code discarded the error and reported success anyway.
// On success:
//   - the node itself gets a StateSelfUpdate, and
//   - all other nodes get a StateUpdateExpire for it.
func (api headscaleV1APIServer) ExpireNode(
	ctx context.Context,
	request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) {
	now := time.Now()

	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
		if err := db.NodeSetExpiry(
			tx,
			request.GetNodeId(),
			now,
		); err != nil {
			return nil, err
		}

		return db.GetNodeByID(tx, request.GetNodeId())
	})
	if err != nil {
		return nil, err
	}

	selfUpdate := types.StateUpdate{
		Type:        types.StateSelfUpdate,
		ChangeNodes: types.Nodes{node},
	}
	if selfUpdate.Valid() {
		ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
		api.h.nodeNotifier.NotifyByMachineKey(
			ctx,
			selfUpdate,
			node.MachineKey)
	}

	stateUpdate := types.StateUpdateExpire(node.ID, now)
	if stateUpdate.Valid() {
		// Peers learn about the expiry; the node itself was handled above.
		ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
	}

	log.Trace().
		Str("node", node.Hostname).
		Time("expiry", *node.Expiry).
		Msg("node expired")

	return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
}
// RenameNode changes a node's name and re-reads it in the same write
// transaction. On success, the node's peers get a StatePeerChanged update and
// the refreshed node is returned.
func (api headscaleV1APIServer) RenameNode(
	ctx context.Context,
	request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
	nodeID := request.GetNodeId()
	newName := request.GetNewName()

	renamed, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
		if renameErr := db.RenameNode(tx, nodeID, newName); renameErr != nil {
			return nil, renameErr
		}

		return db.GetNodeByID(tx, nodeID)
	})
	if err != nil {
		return nil, err
	}

	changed := types.StateUpdate{
		Type:        types.StatePeerChanged,
		ChangeNodes: types.Nodes{renamed},
		Message:     "called from api.RenameNode",
	}
	if changed.Valid() {
		notifyCtx := types.NotifyCtx(ctx, "cli-renamenode", renamed.Hostname)
		api.h.nodeNotifier.NotifyWithIgnore(notifyCtx, changed, renamed.MachineKey.String())
	}

	log.Trace().
		Str("node", renamed.Hostname).
		Str("new_name", newName).
		Msg("node renamed")

	return &v1.RenameNodeResponse{Node: renamed.Proto()}, nil
}
// ListNodes returns all nodes, or only those owned by request.User when it is
// non-empty.
//
// Both the filtered and the unfiltered listing get the same per-node
// treatment:
//   - Online is taken from the notifier's connection map, captured once at
//     the start of the call.
//   - ValidTags and InvalidTags come from the ACL policy.
func (api headscaleV1APIServer) ListNodes(
	ctx context.Context,
	request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) {
	isConnected := api.h.nodeNotifier.ConnectedMap()

	var (
		nodes types.Nodes
		err   error
	)
	if user := request.GetUser(); user != "" {
		nodes, err = db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
			return db.ListNodesByUser(rx, user)
		})
	} else {
		nodes, err = api.h.db.ListNodes()
	}
	if err != nil {
		return nil, err
	}

	response := make([]*v1.Node, len(nodes))
	for index, node := range nodes {
		resp := node.Proto()

		// Populate the online field based on
		// currently connected nodes.
		resp.Online = isConnected[node.MachineKey]

		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
			node,
		)
		resp.InvalidTags = invalidTags
		resp.ValidTags = validTags

		response[index] = resp
	}

	return &v1.ListNodesResponse{Nodes: response}, nil
}
// MoveNode reassigns the requested node to another user. It returns the node
// value that was passed to AssignNodeToUser, converted to protobuf.
func (api headscaleV1APIServer) MoveNode(
	ctx context.Context,
	request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) {
	node, err := api.h.db.GetNodeByID(request.GetNodeId())
	if err != nil {
		return nil, err
	}

	if err = api.h.db.AssignNodeToUser(node, request.GetUser()); err != nil {
		return nil, err
	}

	return &v1.MoveNodeResponse{Node: node.Proto()}, nil
}
// GetRoutes returns every route known to the database, read in a read-only
// transaction.
func (api headscaleV1APIServer) GetRoutes(
	ctx context.Context,
	request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) {
	allRoutes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
		return db.GetRoutes(rx)
	})
	if err != nil {
		return nil, err
	}

	// db.Read is instantiated with types.Routes, so no conversion is needed.
	resp := &v1.GetRoutesResponse{Routes: allRoutes.Proto()}

	return resp, nil
}
// EnableRoute enables the requested route in a write transaction. If the
// database layer returns a valid state update, it is broadcast to all
// connected nodes.
func (api headscaleV1APIServer) EnableRoute(
	ctx context.Context,
	request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) {
	routeID := request.GetRouteId()

	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
		return db.EnableRoute(tx, routeID)
	})
	if err != nil {
		return nil, err
	}

	if update == nil || !update.Valid() {
		return &v1.EnableRouteResponse{}, nil
	}

	notifyCtx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
	api.h.nodeNotifier.NotifyAll(notifyCtx, *update)

	return &v1.EnableRouteResponse{}, nil
}
// DisableRoute disables the requested route in a write transaction. The
// notifier's connection map, captured before the write, is passed to the
// database layer. Any valid resulting state update is broadcast to all nodes.
func (api headscaleV1APIServer) DisableRoute(
	ctx context.Context,
	request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) {
	connected := api.h.nodeNotifier.ConnectedMap()
	routeID := request.GetRouteId()

	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
		return db.DisableRoute(tx, routeID, connected)
	})
	if err != nil {
		return nil, err
	}

	if update == nil || !update.Valid() {
		return &v1.DisableRouteResponse{}, nil
	}

	notifyCtx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
	api.h.nodeNotifier.NotifyAll(notifyCtx, *update)

	return &v1.DisableRouteResponse{}, nil
}
// GetNodeRoutes returns the routes advertised by or assigned to a single node,
// looked up by node ID.
func (api headscaleV1APIServer) GetNodeRoutes(
	ctx context.Context,
	request *v1.GetNodeRoutesRequest,
) (*v1.GetNodeRoutesResponse, error) {
	owner, err := api.h.db.GetNodeByID(request.GetNodeId())
	if err != nil {
		return nil, err
	}

	nodeRoutes, err := api.h.db.GetNodeRoutes(owner)
	if err != nil {
		return nil, err
	}

	resp := &v1.GetNodeRoutesResponse{
		Routes: types.Routes(nodeRoutes).Proto(),
	}

	return resp, nil
}
// DeleteRoute deletes the requested route in a write transaction. The
// notifier's connection map is passed to the database layer.
//
// Any valid resulting state update is broadcast to every node with NotifyAll.
// This is the same broadcast EnableRoute and DisableRoute use. The previous
// NotifyWithIgnore call passed no ignore list, so it reached the same
// audience less explicitly.
func (api headscaleV1APIServer) DeleteRoute(
	ctx context.Context,
	request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) {
	isConnected := api.h.nodeNotifier.ConnectedMap()
	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
		return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
	})
	if err != nil {
		return nil, err
	}

	if update != nil && update.Valid() {
		ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
		api.h.nodeNotifier.NotifyAll(ctx, *update)
	}

	return &v1.DeleteRouteResponse{}, nil
}
// CreateApiKey creates a new API key and returns its string form. A missing
// expiration is passed on as the zero time.Time.
func (api headscaleV1APIServer) CreateApiKey(
	ctx context.Context,
	request *v1.CreateApiKeyRequest,
) (*v1.CreateApiKeyResponse, error) {
	expiration := time.Time{}
	if ts := request.GetExpiration(); ts != nil {
		expiration = ts.AsTime()
	}

	// Only the first return value is sent back to the client; the second
	// (presumably the stored key record) is not needed here.
	apiKey, _, err := api.h.db.CreateAPIKey(&expiration)
	if err != nil {
		return nil, err
	}

	return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil
}
// ExpireApiKey looks up an API key by its prefix and marks it expired.
func (api headscaleV1APIServer) ExpireApiKey(
	ctx context.Context,
	request *v1.ExpireApiKeyRequest,
) (*v1.ExpireApiKeyResponse, error) {
	apiKey, err := api.h.db.GetAPIKey(request.Prefix)
	if err != nil {
		return nil, err
	}

	if err = api.h.db.ExpireAPIKey(apiKey); err != nil {
		return nil, err
	}

	return &v1.ExpireApiKeyResponse{}, nil
}
// ListApiKeys returns every API key in protobuf form.
func (api headscaleV1APIServer) ListApiKeys(
	ctx context.Context,
	request *v1.ListApiKeysRequest,
) (*v1.ListApiKeysResponse, error) {
	stored, err := api.h.db.ListAPIKeys()
	if err != nil {
		return nil, err
	}

	out := make([]*v1.ApiKey, 0, len(stored))
	for _, k := range stored {
		out = append(out, k.Proto())
	}

	return &v1.ListApiKeysResponse{ApiKeys: out}, nil
}
// The following service calls are for testing and debugging

// DebugCreateNode builds a synthetic node for the requested user and stores
// it in the registration cache under its machine key, expiring after
// registerCacheExpiration. It does not write the node to the database itself.
//
// The node gets:
//   - routable IPs parsed from request.Routes,
//   - a freshly generated node key,
//   - a given name derived from the requested name, and
//   - zero-valued Expiry and LastSeen timestamps.
//
// Errors are returned in this order: user lookup, route parsing, machine-key
// parsing, name generation.
func (api headscaleV1APIServer) DebugCreateNode(
	ctx context.Context,
	request *v1.DebugCreateNodeRequest,
) (*v1.DebugCreateNodeResponse, error) {
	owner, err := api.h.db.GetUser(request.GetUser())
	if err != nil {
		return nil, err
	}

	prefixes, err := util.StringToIPPrefix(request.GetRoutes())
	if err != nil {
		return nil, err
	}

	log.Trace().
		Caller().
		Interface("route-prefix", prefixes).
		Interface("route-str", request.GetRoutes()).
		Msg("")

	var machineKey key.MachinePublic
	if err = machineKey.UnmarshalText([]byte(request.GetKey())); err != nil {
		return nil, err
	}

	givenName, err := api.h.db.GenerateGivenName(machineKey, request.GetName())
	if err != nil {
		return nil, err
	}

	debugNode := types.Node{
		MachineKey: machineKey,
		NodeKey:    key.NewNode().Public(),
		Hostname:   request.GetName(),
		GivenName:  givenName,
		User:       *owner,
		Expiry:     &time.Time{},
		LastSeen:   &time.Time{},
		Hostinfo: &tailcfg.Hostinfo{
			RoutableIPs: prefixes,
			OS:          "TestOS",
			Hostname:    "DebugTestNode",
		},
	}

	log.Debug().
		Str("machine_key", machineKey.ShortString()).
		Msg("adding debug machine via CLI, appending to registration cache")

	api.h.registrationCache.Set(
		machineKey.String(),
		debugNode,
		registerCacheExpiration,
	)

	return &v1.DebugCreateNodeResponse{Node: debugNode.Proto()}, nil
}
// mustEmbedUnimplementedHeadscaleServiceServer is the marker method whose name
// matches the one the generated v1.HeadscaleServiceServer interface presumably
// requires (grpc-go forward-compatibility convention; confirm against the
// generated code). Defining it on the value receiver shadows the version
// promoted from the embedded v1.UnimplementedHeadscaleServiceServer.
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}