headscale/grpcv1.go
Kristoffer Dalby 469551bc5d Register new machines needing callback in memory
This commit stores temporary registration data in cache instead of
memory, allowing us to keep only actually registered machines in the
database.
2022-02-28 08:06:39 +00:00

427 lines
9.8 KiB
Go

//nolint
package headscale
import (
"context"
"encoding/json"
"time"
"github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"gorm.io/datatypes"
"tailscale.com/tailcfg"
)
// headscaleV1APIServer is the gRPC handler for the headscale v1 API; it is
// handed out as a v1.HeadscaleServiceServer by newHeadscaleV1APIServer and
// forwards each RPC to the wrapped *Headscale.
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
	// Embedded per the grpc-go convention; presumably supplies default
	// "unimplemented" handlers for any RPC not defined on this type.
	v1.UnimplementedHeadscaleServiceServer
	// h is the Headscale instance every handler method delegates to.
	h *Headscale
}
// newHeadscaleV1APIServer builds the v1 gRPC service implementation backed
// by the given Headscale instance.
func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer {
	server := headscaleV1APIServer{h: h}

	return server
}
// GetNamespace looks up the namespace named in the request and returns its
// protobuf form. Any lookup error is returned to the caller unchanged.
func (api headscaleV1APIServer) GetNamespace(
	ctx context.Context,
	request *v1.GetNamespaceRequest,
) (*v1.GetNamespaceResponse, error) {
	found, lookupErr := api.h.GetNamespace(request.GetName())
	if lookupErr != nil {
		return nil, lookupErr
	}

	resp := &v1.GetNamespaceResponse{
		Namespace: found.toProto(),
	}

	return resp, nil
}
// CreateNamespace creates a namespace with the requested name and returns the
// created namespace in protobuf form. Creation errors are passed through.
func (api headscaleV1APIServer) CreateNamespace(
	ctx context.Context,
	request *v1.CreateNamespaceRequest,
) (*v1.CreateNamespaceResponse, error) {
	created, createErr := api.h.CreateNamespace(request.GetName())
	if createErr != nil {
		return nil, createErr
	}

	resp := &v1.CreateNamespaceResponse{
		Namespace: created.toProto(),
	}

	return resp, nil
}
// RenameNamespace renames a namespace from the request's old name to its new
// name, then fetches the namespace again under the new name and returns it.
// Errors from either step are passed through.
func (api headscaleV1APIServer) RenameNamespace(
	ctx context.Context,
	request *v1.RenameNamespaceRequest,
) (*v1.RenameNamespaceResponse, error) {
	oldName, newName := request.GetOldName(), request.GetNewName()

	if renameErr := api.h.RenameNamespace(oldName, newName); renameErr != nil {
		return nil, renameErr
	}

	// Look the namespace up again by its new name so the response carries
	// what GetNamespace reports after the rename.
	renamed, lookupErr := api.h.GetNamespace(newName)
	if lookupErr != nil {
		return nil, lookupErr
	}

	return &v1.RenameNamespaceResponse{Namespace: renamed.toProto()}, nil
}
// DeleteNamespace destroys the namespace named in the request and returns an
// empty response on success.
func (api headscaleV1APIServer) DeleteNamespace(
	ctx context.Context,
	request *v1.DeleteNamespaceRequest,
) (*v1.DeleteNamespaceResponse, error) {
	if destroyErr := api.h.DestroyNamespace(request.GetName()); destroyErr != nil {
		return nil, destroyErr
	}

	return &v1.DeleteNamespaceResponse{}, nil
}
// ListNamespaces returns every namespace known to headscale, converted to its
// protobuf form, and trace-logs the converted list.
func (api headscaleV1APIServer) ListNamespaces(
	ctx context.Context,
	request *v1.ListNamespacesRequest,
) (*v1.ListNamespacesResponse, error) {
	all, listErr := api.h.ListNamespaces()
	if listErr != nil {
		return nil, listErr
	}

	protos := make([]*v1.Namespace, len(all))
	for i, ns := range all {
		protos[i] = ns.toProto()
	}

	log.Trace().Caller().Interface("namespaces", protos).Msg("")

	return &v1.ListNamespacesResponse{Namespaces: protos}, nil
}
// CreatePreAuthKey creates a pre-auth key in the requested namespace with the
// requested reusable/ephemeral flags and returns it in protobuf form.
//
// When the request carries no expiration, a pointer to the zero time.Time is
// passed on (never nil).
func (api headscaleV1APIServer) CreatePreAuthKey(
	ctx context.Context,
	request *v1.CreatePreAuthKeyRequest,
) (*v1.CreatePreAuthKeyResponse, error) {
	expiration := time.Time{}
	if ts := request.GetExpiration(); ts != nil {
		expiration = ts.AsTime()
	}

	key, createErr := api.h.CreatePreAuthKey(
		request.GetNamespace(),
		request.GetReusable(),
		request.GetEphemeral(),
		&expiration,
	)
	if createErr != nil {
		return nil, createErr
	}

	return &v1.CreatePreAuthKeyResponse{PreAuthKey: key.toProto()}, nil
}
// ExpirePreAuthKey looks up the pre-auth key identified by the request's
// namespace and key, then expires it. It returns an empty response on success
// and passes lookup/expiry errors through unchanged.
func (api headscaleV1APIServer) ExpirePreAuthKey(
	ctx context.Context,
	request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
	// Use the generated GetKey accessor (nil-receiver safe in protoc-gen-go
	// output) rather than the raw field, consistent with the rest of this API.
	preAuthKey, err := api.h.GetPreAuthKey(request.GetNamespace(), request.GetKey())
	if err != nil {
		return nil, err
	}

	if err := api.h.ExpirePreAuthKey(preAuthKey); err != nil {
		return nil, err
	}

	return &v1.ExpirePreAuthKeyResponse{}, nil
}
// ListPreAuthKeys returns all pre-auth keys of the requested namespace in
// protobuf form.
func (api headscaleV1APIServer) ListPreAuthKeys(
	ctx context.Context,
	request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
	keys, listErr := api.h.ListPreAuthKeys(request.GetNamespace())
	if listErr != nil {
		return nil, listErr
	}

	protos := make([]*v1.PreAuthKey, len(keys))
	for i, k := range keys {
		protos[i] = k.toProto()
	}

	return &v1.ListPreAuthKeysResponse{PreAuthKeys: protos}, nil
}
// RegisterMachine registers the machine identified by the request's machine
// key into the requested namespace, using RegisterMethodCLI, and returns the
// registered machine in protobuf form.
//
// If a requested expiry is cached in requestedExpiryCache under the machine
// key and holds a time.Time, it is used; otherwise the zero time is passed.
func (api headscaleV1APIServer) RegisterMachine(
	ctx context.Context,
	request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) {
	// Read the key once through the nil-safe accessor; it is used for the
	// cache lookup, logging and registration below.
	machineKey := request.GetKey()

	log.Trace().
		Str("namespace", request.GetNamespace()).
		Str("machine_key", machineKey).
		Msg("Registering machine")

	// TODO(kradalby): Currently, if it fails to find a requested expiry, none will be set.
	// This means that if a user is too slow to register a machine, it will possibly not
	// have the correct expiry.
	requestedTime := time.Time{}
	if requestedTimeIf, found := api.h.requestedExpiryCache.Get(machineKey); found {
		log.Trace().
			Caller().
			Str("machine", machineKey).
			Msg("Expiry time found in cache, assigning to node")

		if reqTime, ok := requestedTimeIf.(time.Time); ok {
			requestedTime = reqTime
		} else {
			// Previously this fell through silently; make unexpected cache
			// contents visible while keeping the zero-expiry fallback.
			log.Warn().
				Caller().
				Str("machine", machineKey).
				Msg("Cached requested expiry is not a time.Time, ignoring it")
		}
	}

	machine, err := api.h.RegisterMachineFromAuthCallback(
		machineKey,
		request.GetNamespace(),
		RegisterMethodCLI,
		&requestedTime,
	)
	if err != nil {
		return nil, err
	}

	return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil
}
// GetMachine fetches a machine by the ID in the request and returns it in
// protobuf form.
func (api headscaleV1APIServer) GetMachine(
	ctx context.Context,
	request *v1.GetMachineRequest,
) (*v1.GetMachineResponse, error) {
	found, lookupErr := api.h.GetMachineByID(request.GetMachineId())
	if lookupErr != nil {
		return nil, lookupErr
	}

	return &v1.GetMachineResponse{
		Machine: found.toProto(),
	}, nil
}
// DeleteMachine looks up the machine by the ID in the request and deletes it,
// returning an empty response on success.
func (api headscaleV1APIServer) DeleteMachine(
	ctx context.Context,
	request *v1.DeleteMachineRequest,
) (*v1.DeleteMachineResponse, error) {
	target, lookupErr := api.h.GetMachineByID(request.GetMachineId())
	if lookupErr != nil {
		return nil, lookupErr
	}

	if deleteErr := api.h.DeleteMachine(target); deleteErr != nil {
		return nil, deleteErr
	}

	return &v1.DeleteMachineResponse{}, nil
}
// ExpireMachine looks up the machine by the ID in the request, expires it via
// Headscale.ExpireMachine, trace-logs the resulting expiry and returns the
// machine in protobuf form.
func (api headscaleV1APIServer) ExpireMachine(
	ctx context.Context,
	request *v1.ExpireMachineRequest,
) (*v1.ExpireMachineResponse, error) {
	target, lookupErr := api.h.GetMachineByID(request.GetMachineId())
	if lookupErr != nil {
		return nil, lookupErr
	}

	api.h.ExpireMachine(target)

	// NOTE(review): Expiry is dereferenced unconditionally, so this assumes
	// ExpireMachine always sets it — confirm.
	expiry := *target.Expiry
	log.Trace().
		Str("machine", target.Name).
		Time("expiry", expiry).
		Msg("machine expired")

	resp := &v1.ExpireMachineResponse{Machine: target.toProto()}

	return resp, nil
}
// ListMachines returns machines in protobuf form: all machines when the
// request's namespace is empty, otherwise only those in that namespace.
func (api headscaleV1APIServer) ListMachines(
	ctx context.Context,
	request *v1.ListMachinesRequest,
) (*v1.ListMachinesResponse, error) {
	namespaceName := request.GetNamespace()

	if namespaceName == "" {
		allMachines, listErr := api.h.ListMachines()
		if listErr != nil {
			return nil, listErr
		}

		protos := make([]*v1.Machine, len(allMachines))
		for i, m := range allMachines {
			protos[i] = m.toProto()
		}

		return &v1.ListMachinesResponse{Machines: protos}, nil
	}

	scoped, listErr := api.h.ListMachinesInNamespace(namespaceName)
	if listErr != nil {
		return nil, listErr
	}

	protos := make([]*v1.Machine, len(scoped))
	for i, m := range scoped {
		protos[i] = m.toProto()
	}

	return &v1.ListMachinesResponse{Machines: protos}, nil
}
// GetMachineRoute returns the routes of the machine identified by the request,
// as produced by Machine.RoutesToProto.
func (api headscaleV1APIServer) GetMachineRoute(
	ctx context.Context,
	request *v1.GetMachineRouteRequest,
) (*v1.GetMachineRouteResponse, error) {
	target, lookupErr := api.h.GetMachineByID(request.GetMachineId())
	if lookupErr != nil {
		return nil, lookupErr
	}

	protoRoutes, convErr := target.RoutesToProto()
	if convErr != nil {
		return nil, convErr
	}

	resp := &v1.GetMachineRouteResponse{Routes: protoRoutes}

	return resp, nil
}
// EnableMachineRoutes enables the requested routes on the machine identified
// by the request, then returns the machine's routes via RoutesToProto.
func (api headscaleV1APIServer) EnableMachineRoutes(
	ctx context.Context,
	request *v1.EnableMachineRoutesRequest,
) (*v1.EnableMachineRoutesResponse, error) {
	target, lookupErr := api.h.GetMachineByID(request.GetMachineId())
	if lookupErr != nil {
		return nil, lookupErr
	}

	if enableErr := api.h.EnableRoutes(target, request.GetRoutes()...); enableErr != nil {
		return nil, enableErr
	}

	// Routes are read back from the same machine value EnableRoutes was given.
	protoRoutes, convErr := target.RoutesToProto()
	if convErr != nil {
		return nil, convErr
	}

	resp := &v1.EnableMachineRoutesResponse{Routes: protoRoutes}

	return resp, nil
}
// CreateApiKey creates a new API key and returns the first value produced by
// Headscale.CreateAPIKey; its second result is discarded.
//
// When the request carries no expiration, a pointer to the zero time.Time is
// passed on (never nil).
func (api headscaleV1APIServer) CreateApiKey(
	ctx context.Context,
	request *v1.CreateApiKeyRequest,
) (*v1.CreateApiKeyResponse, error) {
	expiration := time.Time{}
	if ts := request.GetExpiration(); ts != nil {
		expiration = ts.AsTime()
	}

	created, _, createErr := api.h.CreateAPIKey(&expiration)
	if createErr != nil {
		return nil, createErr
	}

	return &v1.CreateApiKeyResponse{ApiKey: created}, nil
}
// ExpireApiKey looks up the API key by the prefix given in the request and
// expires it. It returns an empty response on success and passes lookup/expiry
// errors through unchanged.
func (api headscaleV1APIServer) ExpireApiKey(
	ctx context.Context,
	request *v1.ExpireApiKeyRequest,
) (*v1.ExpireApiKeyResponse, error) {
	// Use the generated GetPrefix accessor (nil-receiver safe in protoc-gen-go
	// output) instead of the raw field, consistent with the rest of this API.
	apiKey, err := api.h.GetAPIKey(request.GetPrefix())
	if err != nil {
		return nil, err
	}

	if err := api.h.ExpireAPIKey(apiKey); err != nil {
		return nil, err
	}

	return &v1.ExpireApiKeyResponse{}, nil
}
// ListApiKeys returns every stored API key in protobuf form.
func (api headscaleV1APIServer) ListApiKeys(
	ctx context.Context,
	request *v1.ListApiKeysRequest,
) (*v1.ListApiKeysResponse, error) {
	keys, listErr := api.h.ListAPIKeys()
	if listErr != nil {
		return nil, listErr
	}

	protos := make([]*v1.ApiKey, len(keys))
	for i, k := range keys {
		protos[i] = k.toProto()
	}

	return &v1.ListApiKeysResponse{ApiKeys: protos}, nil
}
// The following service calls are for testing and debugging.

// DebugCreateMachine builds a Machine in the requested namespace with a
// synthetic Hostinfo advertising the requested routes, and writes it directly
// with api.h.db.Create (not through RegisterMachineFromAuthCallback). Expiry,
// LastSeen and LastSuccessfulUpdate are set to pointers to the zero time.
func (api headscaleV1APIServer) DebugCreateMachine(
	ctx context.Context,
	request *v1.DebugCreateMachineRequest,
) (*v1.DebugCreateMachineResponse, error) {
	namespace, err := api.h.GetNamespace(request.GetNamespace())
	if err != nil {
		return nil, err
	}

	routes, err := stringToIPPrefix(request.GetRoutes())
	if err != nil {
		return nil, err
	}

	log.Trace().
		Caller().
		Interface("route-prefix", routes).
		Interface("route-str", request.GetRoutes()).
		Msg("")

	hostinfo := tailcfg.Hostinfo{
		RoutableIPs: routes,
		OS:          "TestOS",
		Hostname:    "DebugTestMachine",
	}

	log.Trace().Caller().Interface("hostinfo", hostinfo).Msg("")

	// Machine.HostInfo is stored as JSON, so serialize the Hostinfo first.
	hostinfoJSON, err := json.Marshal(hostinfo)
	if err != nil {
		return nil, err
	}

	newMachine := Machine{
		MachineKey:           request.GetKey(),
		Name:                 request.GetName(),
		Namespace:            *namespace,
		Expiry:               &time.Time{},
		LastSeen:             &time.Time{},
		LastSuccessfulUpdate: &time.Time{},
		HostInfo:             datatypes.JSON(hostinfoJSON),
	}

	if err := api.h.db.Create(&newMachine).Error; err != nil {
		return nil, err
	}

	return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
}
// mustEmbedUnimplementedHeadscaleServiceServer is an intentionally empty
// marker method with a value receiver.
// NOTE(review): the embedded v1.UnimplementedHeadscaleServiceServer presumably
// already provides this method, which would make this declaration redundant —
// confirm against the generated code before removing it.
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}