Merge pull request #656 from juanfont/abandon-gin

Drop Gin as web framework for TS2019 API
This commit is contained in:
Juan Font 2022-06-26 15:54:41 +02:00 committed by GitHub
commit 4a200c308b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 966 additions and 410 deletions

View file

@ -30,8 +30,11 @@
- Add -c option to specify config file from command line [#285](https://github.com/juanfont/headscale/issues/285) [#612](https://github.com/juanfont/headscale/pull/601) - Add -c option to specify config file from command line [#285](https://github.com/juanfont/headscale/issues/285) [#612](https://github.com/juanfont/headscale/pull/601)
- Add configuration option to allow Tailscale clients to use a random WireGuard port. [kb/1181/firewalls](https://tailscale.com/kb/1181/firewalls) [#624](https://github.com/juanfont/headscale/pull/624) - Add configuration option to allow Tailscale clients to use a random WireGuard port. [kb/1181/firewalls](https://tailscale.com/kb/1181/firewalls) [#624](https://github.com/juanfont/headscale/pull/624)
- Improve obtuse UX regarding missing configuration (`ephemeral_node_inactivity_timeout` not set) [#639](https://github.com/juanfont/headscale/pull/639) - Improve obtuse UX regarding missing configuration (`ephemeral_node_inactivity_timeout` not set) [#639](https://github.com/juanfont/headscale/pull/639)
- Fix nodes being shown as 'offline' in `tailscale status` [648](https://github.com/juanfont/headscale/pull/648)
- Fix nodes being shown as 'offline' in `tailscale status` [#648](https://github.com/juanfont/headscale/pull/648) - Fix nodes being shown as 'offline' in `tailscale status` [#648](https://github.com/juanfont/headscale/pull/648)
- Improve shutdown behaviour [#651](https://github.com/juanfont/headscale/pull/651) - Improve shutdown behaviour [#651](https://github.com/juanfont/headscale/pull/651)
- Drop Gin as web framework in Headscale [648](https://github.com/juanfont/headscale/pull/648)
## 0.15.0 (2022-03-20) ## 0.15.0 (2022-03-20)

View file

@ -37,7 +37,7 @@ const (
expectedTokenItems = 2 expectedTokenItems = 2
) )
// For some reason golang.org/x/net/internal/iana is an internal package // For some reason golang.org/x/net/internal/iana is an internal package.
const ( const (
protocolICMP = 1 // Internet Control Message protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management protocolIGMP = 2 // Internet Group Management

293
api.go
View file

@ -12,7 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gorilla/mux"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
@ -32,12 +32,19 @@ const (
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) { func (h *Headscale) KeyHandler(
ctx.Data( writer http.ResponseWriter,
http.StatusOK, req *http.Request,
"text/plain; charset=utf-8", ) {
[]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())), writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
) writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
type registerWebAPITemplateConfig struct { type registerWebAPITemplateConfig struct {
@ -63,10 +70,21 @@ var registerWebAPITemplate = template.Must(
// RegisterWebAPI shows a simple message in the browser to point to the CLI // RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register. // Listens in /register.
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { func (h *Headscale) RegisterWebAPI(
machineKeyStr := ctx.Query("key") writer http.ResponseWriter,
req *http.Request,
) {
machineKeyStr := req.URL.Query().Get("key")
if machineKeyStr == "" { if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -79,21 +97,48 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
Str("func", "RegisterWebAPI"). Str("func", "RegisterWebAPI").
Err(err). Err(err).
Msg("Could not render register web API template") Msg("Could not render register web API template")
ctx.Data( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"text/html; charset=utf-8", _, err = writer.Write([]byte("Could not render register web API template"))
[]byte("Could not render register web API template"), if err != nil {
) log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
// RegistrationHandler handles the actual registration process of a machine // RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id. // Endpoint /machine/:mkey.
func (h *Headscale) RegistrationHandler(ctx *gin.Context) { func (h *Headscale) RegistrationHandler(
body, _ := io.ReadAll(ctx.Request.Body) writer http.ResponseWriter,
machineKeyStr := ctx.Param("id") req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "RegistrationHandler").
Msg("No machine ID in request")
http.Error(writer, "No machine ID in request", http.StatusBadRequest)
return
}
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
@ -103,19 +148,19 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
Err(err). Err(err).
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
ctx.String(http.StatusInternalServerError, "Sad!") http.Error(writer, "Cannot parse machine key", http.StatusBadRequest)
return return
} }
req := tailcfg.RegisterRequest{} registerRequest := tailcfg.RegisterRequest{}
err = decode(body, &req, &machineKey, h.privateKey) err = decode(body, &registerRequest, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
ctx.String(http.StatusInternalServerError, "Very sad!") http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return return
} }
@ -123,23 +168,23 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
now := time.Now().UTC() now := time.Now().UTC()
machine, err := h.GetMachineByMachineKey(machineKey) machine, err := h.GetMachineByMachineKey(machineKey)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine")
machineKeyStr := MachinePublicKeyStripPrefix(machineKey) machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
// If the machine has AuthKey set, handle registration via PreAuthKeys // If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, req) h.handleAuthKey(writer, req, machineKey, registerRequest)
return return
} }
givenName, err := h.GenerateGivenName(req.Hostinfo.Hostname) givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "RegistrationHandler"). Str("func", "RegistrationHandler").
Str("hostinfo.name", req.Hostinfo.Hostname). Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err) Err(err)
return return
@ -151,20 +196,20 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
// happens // happens
newMachine := Machine{ newMachine := Machine{
MachineKey: machineKeyStr, MachineKey: machineKeyStr,
Hostname: req.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey), NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
LastSeen: &now, LastSeen: &now,
Expiry: &time.Time{}, Expiry: &time.Time{},
} }
if !req.Expiry.IsZero() { if !registerRequest.Expiry.IsZero() {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", req.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", req.Expiry). Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested") Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry newMachine.Expiry = &registerRequest.Expiry
} }
h.registrationCache.Set( h.registrationCache.Set(
@ -173,7 +218,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
registerCacheExpiration, registerCacheExpiration,
) )
h.handleMachineRegistrationNew(ctx, machineKey, req) h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest)
return return
} }
@ -185,11 +230,11 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map // - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate // - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) { if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { if !registerRequest.Expiry.IsZero() && registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOut(ctx, machineKey, *machine) h.handleMachineLogOut(writer, req, machineKey, *machine)
return return
} }
@ -197,22 +242,22 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
// If machine is not expired, and is register, we have a already accepted this machine, // If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration // let it proceed with a valid registration
if !machine.isExpired() { if !machine.isExpired() {
h.handleMachineValidRegistration(ctx, machineKey, *machine) h.handleMachineValidRegistration(writer, req, machineKey, *machine)
return return
} }
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) && if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() { !machine.isExpired() {
h.handleMachineRefreshKey(ctx, machineKey, req, *machine) h.handleMachineRefreshKey(writer, req, machineKey, registerRequest, *machine)
return return
} }
// The machine has expired // The machine has expired
h.handleMachineExpired(ctx, machineKey, req, *machine) h.handleMachineExpired(writer, req, machineKey, registerRequest, *machine)
return return
} }
@ -220,12 +265,12 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
func (h *Headscale) getMapResponse( func (h *Headscale) getMapResponse(
machineKey key.MachinePublic, machineKey key.MachinePublic,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *Machine,
) ([]byte, error) { ) ([]byte, error) {
log.Trace(). log.Trace().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
@ -286,12 +331,12 @@ func (h *Headscale) getMapResponse(
log.Trace(). log.Trace().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp). // Interface("payload", resp).
Msgf("Generated map response: %s", tailMapResponseToString(resp)) Msgf("Generated map response: %s", tailMapResponseToString(resp))
var respBody []byte var respBody []byte
if req.Compress == "zstd" { if mapRequest.Compress == "zstd" {
src, err := json.Marshal(resp) src, err := json.Marshal(resp)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -357,7 +402,8 @@ func (h *Headscale) getMapKeepAliveResponse(
} }
func (h *Headscale) handleMachineLogOut( func (h *Headscale) handleMachineLogOut(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
machine Machine, machine Machine,
) { ) {
@ -367,7 +413,17 @@ func (h *Headscale) handleMachineLogOut(
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client requested logout") Msg("Client requested logout")
h.ExpireMachine(&machine) err := h.ExpireMachine(&machine)
if err != nil {
log.Error().
Caller().
Str("func", "handleMachineLogOut").
Err(err).
Msg("Failed to expire machine")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -378,15 +434,25 @@ func (h *Headscale) handleMachineLogOut(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func (h *Headscale) handleMachineValidRegistration( func (h *Headscale) handleMachineValidRegistration(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
machine Machine, machine Machine,
) { ) {
@ -410,17 +476,27 @@ func (h *Headscale) handleMachineValidRegistration(
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func (h *Headscale) handleMachineExpired( func (h *Headscale) handleMachineExpired(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
@ -433,7 +509,7 @@ func (h *Headscale) handleMachineExpired(
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, registerRequest) h.handleAuthKey(writer, req, machineKey, registerRequest)
return return
} }
@ -454,17 +530,27 @@ func (h *Headscale) handleMachineExpired(
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func (h *Headscale) handleMachineRefreshKey( func (h *Headscale) handleMachineRefreshKey(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
@ -481,7 +567,7 @@ func (h *Headscale) handleMachineRefreshKey(
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to update machine key in the database") Msg("Failed to update machine key in the database")
ctx.String(http.StatusInternalServerError, "Internal server error") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
@ -494,15 +580,25 @@ func (h *Headscale) handleMachineRefreshKey(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "Internal server error") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func (h *Headscale) handleMachineRegistrationNew( func (h *Headscale) handleMachineRegistrationNew(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
) { ) {
@ -529,16 +625,26 @@ func (h *Headscale) handleMachineRegistrationNew(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
// TODO: check if any locks are needed around IP allocation. // TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
) { ) {
@ -567,14 +673,23 @@ func (h *Headscale) handleAuthKey(
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
return return
} }
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
@ -611,7 +726,16 @@ func (h *Headscale) handleAuthKey(
machine.NodeKey = nodeKey machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
h.RefreshMachine(machine, registerRequest.Expiry) err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to refresh machine")
return
}
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
@ -648,16 +772,24 @@ func (h *Headscale) handleAuthKey(
Msg("could not register machine") Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
ctx.String( http.Error(writer, "Internal server error", http.StatusInternalServerError)
http.StatusInternalServerError,
"could not register machine",
)
return return
} }
} }
h.UsePreAuthKey(pak) err = h.UsePreAuthKey(pak)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
@ -671,13 +803,22 @@ func (h *Headscale) handleAuthKey(
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).

166
app.go
View file

@ -18,6 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/mux"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -93,6 +94,8 @@ type Headscale struct {
registrationCache *cache.Cache registrationCache *cache.Cache
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
shutdownChan chan struct{}
} }
// Look up the TLS constant relative to user-supplied TLS client // Look up the TLS constant relative to user-supplied TLS client
@ -327,48 +330,74 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
return handler(ctx, req) return handler(ctx, req)
} }
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler {
log.Trace(). return http.HandlerFunc(func(
Caller(). writer http.ResponseWriter,
Str("client_address", ctx.ClientIP()). req *http.Request,
Msg("HTTP authentication invoked") ) {
log.Trace().
authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller(). Caller().
Str("client_address", ctx.ClientIP()). Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg("HTTP authentication invoked")
ctx.AbortWithStatus(http.StatusUnauthorized)
return authHeader := req.Header.Get("authorization")
}
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if !strings.HasPrefix(authHeader, AuthPrefix) {
if err != nil { log.Error().
log.Error(). Caller().
Caller(). Str("client_address", req.RemoteAddr).
Err(err). Msg(`missing "Bearer " prefix in "Authorization" header`)
Str("client_address", ctx.ClientIP()). writer.WriteHeader(http.StatusUnauthorized)
Msg("failed to validate token") _, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
ctx.AbortWithStatus(http.StatusInternalServerError) return
}
return valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
} if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", req.RemoteAddr).
Msg("failed to validate token")
if !valid { writer.WriteHeader(http.StatusInternalServerError)
log.Info(). _, err := writer.Write([]byte("Unauthorized"))
Str("client_address", ctx.ClientIP()). if err != nil {
Msg("invalid token") log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
ctx.AbortWithStatus(http.StatusUnauthorized) return
}
return if !valid {
} log.Info().
Str("client_address", req.RemoteAddr).
Msg("invalid token")
ctx.Next() writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
next.ServeHTTP(writer, req)
})
} }
// ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear // ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear
@ -391,39 +420,48 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine {
return promRouter return promRouter
} }
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router := gin.Default() router := mux.NewRouter()
router.GET( router.HandleFunc(
"/health", "/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, func(writer http.ResponseWriter, req *http.Request) {
) writer.WriteHeader(http.StatusOK)
router.GET("/key", h.KeyHandler) _, err := writer.Write([]byte("{\"healthy\": \"ok\"}"))
router.GET("/register", h.RegisterWebAPI) if err != nil {
router.POST("/machine/:id/map", h.PollNetMapHandler) log.Error().
router.POST("/machine/:id", h.RegistrationHandler) Caller().
router.GET("/oidc/register/:mkey", h.RegisterOIDC) Err(err).
router.GET("/oidc/callback", h.OIDCCallback) Msg("Failed to write response")
router.GET("/apple", h.AppleConfigMessage) }
router.GET("/apple/:platform", h.ApplePlatformConfig) }).Methods(http.MethodGet)
router.GET("/windows", h.WindowsConfigMessage)
router.GET("/windows/tailscale.reg", h.WindowsRegConfig) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.GET("/swagger", SwaggerUI) router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet)
router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
router.Any("/derp", h.DERPHandler) router.HandleFunc("/derp", h.DERPHandler)
router.Any("/derp/probe", h.DERPProbeHandler) router.HandleFunc("/derp/probe", h.DERPProbeHandler)
router.Any("/bootstrap-dns", h.DERPBootstrapDNSHandler) router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler)
} }
api := router.Group("/api") api := router.PathPrefix("/api").Subrouter()
api.Use(h.httpAuthenticationMiddleware) api.Use(h.httpAuthenticationMiddleware)
{ {
api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP)) api.HandleFunc("/v1/*any", grpcMux.ServeHTTP)
} }
router.NoRoute(stdoutHandler) router.PathPrefix("/").HandlerFunc(stdoutHandler)
return router return router
} }
@ -631,6 +669,7 @@ func (h *Headscale) Serve() error {
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)
// Handle common process-killing signals so we can gracefully shut down: // Handle common process-killing signals so we can gracefully shut down:
h.shutdownChan = make(chan struct{})
sigc := make(chan os.Signal, 1) sigc := make(chan os.Signal, 1)
signal.Notify(sigc, signal.Notify(sigc,
syscall.SIGHUP, syscall.SIGHUP,
@ -668,6 +707,8 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
h.shutdownChan <- struct{}{}
// Gracefully shut down servers // Gracefully shut down servers
ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout)
if err := promHTTPServer.Shutdown(ctx); err != nil { if err := promHTTPServer.Shutdown(ctx); err != nil {
@ -831,13 +872,16 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
} }
} }
func stdoutHandler(ctx *gin.Context) { func stdoutHandler(
body, _ := io.ReadAll(ctx.Request.Body) writer http.ResponseWriter,
req *http.Request,
) {
body, _ := io.ReadAll(req.Body)
log.Trace(). log.Trace().
Interface("header", ctx.Request.Header). Interface("header", req.Header).
Interface("proto", ctx.Request.Proto). Interface("proto", req.Proto).
Interface("url", ctx.Request.URL). Interface("url", req.URL).
Bytes("body", body). Bytes("body", body).
Msg("Request did not match") Msg("Request did not match")
} }

5
db.go
View file

@ -89,7 +89,7 @@ func (h *Headscale) initDB() error {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
} }
for _, machine := range machines { for item, machine := range machines {
if machine.GivenName == "" { if machine.GivenName == "" {
normalizedHostname, err := NormalizeToFQDNRules( normalizedHostname, err := NormalizeToFQDNRules(
machine.Hostname, machine.Hostname,
@ -103,7 +103,7 @@ func (h *Headscale) initDB() error {
Msg("Failed to normalize machine hostname in DB migration") Msg("Failed to normalize machine hostname in DB migration")
} }
err = h.RenameMachine(&machine, normalizedHostname) err = h.RenameMachine(&machines[item], normalizedHostname)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -111,7 +111,6 @@ func (h *Headscale) initDB() error {
Err(err). Err(err).
Msg("Failed to save normalized machine name in DB migration") Msg("Failed to save normalized machine name in DB migration")
} }
} }
} }
} }

View file

@ -2,6 +2,7 @@ package headscale
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -10,7 +11,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/net/stun" "tailscale.com/net/stun"
@ -30,6 +30,7 @@ type DERPServer struct {
} }
func (h *Headscale) NewDERPServer() (*DERPServer, error) { func (h *Headscale) NewDERPServer() (*DERPServer, error) {
log.Trace().Caller().Msg("Creating new embedded DERP server")
server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf) server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf)
region, err := h.generateRegionLocalDERP() region, err := h.generateRegionLocalDERP()
if err != nil { if err != nil {
@ -87,30 +88,48 @@ func (h *Headscale) generateRegionLocalDERP() (tailcfg.DERPRegion, error) {
} }
localDERPregion.Nodes[0].STUNPort = portSTUN localDERPregion.Nodes[0].STUNPort = portSTUN
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
return localDERPregion, nil return localDERPregion, nil
} }
func (h *Headscale) DERPHandler(ctx *gin.Context) { func (h *Headscale) DERPHandler(
log.Trace().Caller().Msgf("/derp request from %v", ctx.ClientIP()) writer http.ResponseWriter,
up := strings.ToLower(ctx.Request.Header.Get("Upgrade")) req *http.Request,
) {
log.Trace().Caller().Msgf("/derp request from %v", req.RemoteAddr)
up := strings.ToLower(req.Header.Get("Upgrade"))
if up != "websocket" && up != "derp" { if up != "websocket" && up != "derp" {
if up != "" { if up != "" {
log.Warn().Caller().Msgf("Weird websockets connection upgrade: %q", up) log.Warn().Caller().Msgf("Weird websockets connection upgrade: %q", up)
} }
ctx.String(http.StatusUpgradeRequired, "DERP requires connection upgrade") writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(http.StatusUpgradeRequired)
_, err := writer.Write([]byte("DERP requires connection upgrade"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
fastStart := ctx.Request.Header.Get(fastStartHeader) == "1" fastStart := req.Header.Get(fastStartHeader) == "1"
hijacker, ok := ctx.Writer.(http.Hijacker) hijacker, ok := writer.(http.Hijacker)
if !ok { if !ok {
log.Error().Caller().Msg("DERP requires Hijacker interface from Gin") log.Error().Caller().Msg("DERP requires Hijacker interface from Gin")
ctx.String( writer.Header().Set("Content-Type", "text/plain")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"HTTP does not support general TCP support", _, err := writer.Write([]byte("HTTP does not support general TCP support"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -118,13 +137,19 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) {
netConn, conn, err := hijacker.Hijack() netConn, conn, err := hijacker.Hijack()
if err != nil { if err != nil {
log.Error().Caller().Err(err).Msgf("Hijack failed") log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String( writer.Header().Set("Content-Type", "text/plain")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"HTTP does not support general TCP support", _, err = writer.Write([]byte("HTTP does not support general TCP support"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
if !fastStart { if !fastStart {
pubKey := h.privateKey.Public() pubKey := h.privateKey.Public()
@ -143,12 +168,23 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) {
// DERPProbeHandler is the endpoint that js/wasm clients hit to measure // DERPProbeHandler is the endpoint that js/wasm clients hit to measure
// DERP latency, since they can't do UDP STUN queries. // DERP latency, since they can't do UDP STUN queries.
func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { func (h *Headscale) DERPProbeHandler(
switch ctx.Request.Method { writer http.ResponseWriter,
req *http.Request,
) {
switch req.Method {
case "HEAD", "GET": case "HEAD", "GET":
ctx.Writer.Header().Set("Access-Control-Allow-Origin", "*") writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.WriteHeader(http.StatusOK)
default: default:
ctx.String(http.StatusMethodNotAllowed, "bogus probe method") writer.WriteHeader(http.StatusMethodNotAllowed)
_, err := writer.Write([]byte("bogus probe method"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
} }
@ -159,15 +195,18 @@ func (h *Headscale) DERPProbeHandler(ctx *gin.Context) {
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns // An example implementation is found here https://derp.tailscale.com/bootstrap-dns
func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) { func (h *Headscale) DERPBootstrapDNSHandler(
writer http.ResponseWriter,
req *http.Request,
) {
dnsEntries := make(map[string][]net.IP) dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
var r net.Resolver var resolver net.Resolver
for _, region := range h.DERPMap.Regions { for _, region := range h.DERPMap.Regions {
for _, node := range region.Nodes { // we don't care if we override some nodes for _, node := range region.Nodes { // we don't care if we override some nodes
addrs, err := r.LookupIP(resolvCtx, "ip", node.HostName) addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName)
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().
@ -179,7 +218,15 @@ func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) {
dnsEntries[node.HostName] = addrs dnsEntries[node.HostName] = addrs
} }
} }
ctx.JSON(http.StatusOK, dnsEntries) writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err := json.NewEncoder(writer).Encode(dnsEntries)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
// ServeSTUN starts a STUN server on the configured addr. // ServeSTUN starts a STUN server on the configured addr.

View file

@ -24,7 +24,7 @@
# When updating go.mod or go.sum, a new sha will need to be calculated, # When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files. # update this if you have a mismatch after doing a change to thos files.
vendorSha256 = "sha256-j/hI6vP92UmcexFfzCe5fkGE8QUdQdNajSxMGib175Q="; vendorSha256 = "sha256-T6rH+aqofFmCPxDfoA5xd3kNUJeZkT4GRyuFEnenps8=";
ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ]; ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ];
}; };

1
go.mod
View file

@ -11,6 +11,7 @@ require (
github.com/gin-gonic/gin v1.7.7 github.com/gin-gonic/gin v1.7.7
github.com/glebarez/sqlite v1.4.3 github.com/glebarez/sqlite v1.4.3
github.com/gofrs/uuid v4.2.0+incompatible github.com/gofrs/uuid v4.2.0+incompatible
github.com/gorilla/mux v1.8.0
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.0
github.com/klauspost/compress v1.15.4 github.com/klauspost/compress v1.15.4

1
go.sum
View file

@ -403,6 +403,7 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU=
github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw= github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw=
github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA= github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

View file

@ -27,6 +27,7 @@ const (
errCouldNotConvertMachineInterface = Error("failed to convert machine interface") errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long") errHostnameTooLong = Error("Hostname too long")
MachineGivenNameHashLength = 8 MachineGivenNameHashLength = 8
MachineGivenNameTrimSize = 2
) )
const ( const (
@ -898,7 +899,7 @@ func (machine *Machine) RoutesToProto() *v1.Routes {
func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) { func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) {
// If a hostname is or will be longer than 63 chars after adding the hash, // If a hostname is or will be longer than 63 chars after adding the hash,
// it needs to be trimmed. // it needs to be trimmed.
trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2 trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize
normalizedHostname, err := NormalizeToFQDNRules( normalizedHostname, err := NormalizeToFQDNRules(
suppliedName, suppliedName,

View file

@ -249,10 +249,12 @@ func (s *Suite) TestExpireMachine(c *check.C) {
machineFromDB, err := app.GetMachine("test", "testmachine") machineFromDB, err := app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(machineFromDB, check.NotNil)
c.Assert(machineFromDB.isExpired(), check.Equals, false) c.Assert(machineFromDB.isExpired(), check.Equals, false)
app.ExpireMachine(machineFromDB) err = app.ExpireMachine(machineFromDB)
c.Assert(err, check.IsNil)
c.Assert(machineFromDB.isExpired(), check.Equals, true) c.Assert(machineFromDB.isExpired(), check.Equals, true)
} }
@ -918,6 +920,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
err, err,
tt.wantErr, tt.wantErr,
) )
return return
} }

256
oidc.go
View file

@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -63,10 +63,17 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(
machineKeyStr := ctx.Param("mkey") writer http.ResponseWriter,
if machineKeyStr == "" { req *http.Request,
ctx.String(http.StatusBadRequest, "Wrong params") ) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Caller().
Msg("Missing machine key in URL")
http.Error(writer, "Missing machine key in URL", http.StatusBadRequest)
return return
} }
@ -81,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Error(). log.Error().
Caller(). Caller().
Msg("could not read 16 bytes from rand") Msg("could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
@ -101,7 +108,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...) authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
ctx.Redirect(http.StatusFound, authURL) http.Redirect(writer, req, authURL, http.StatusFound)
} }
type oidcCallbackTemplateConfig struct { type oidcCallbackTemplateConfig struct {
@ -125,12 +132,23 @@ var oidcCallbackTemplate = template.Must(
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(ctx *gin.Context) { func (h *Headscale) OIDCCallback(
code := ctx.Query("code") writer http.ResponseWriter,
state := ctx.Query("state") req *http.Request,
) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
if code == "" || state == "" { if code == "" || state == "" {
ctx.String(http.StatusBadRequest, "Wrong params") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -141,7 +159,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("Could not exchange code for token") Msg("Could not exchange code for token")
ctx.String(http.StatusBadRequest, "Could not exchange code for token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not exchange code for token"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -154,7 +180,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
ctx.String(http.StatusBadRequest, "Could not extract ID Token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -167,7 +201,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("failed to verify id token") Msg("failed to verify id token")
ctx.String(http.StatusBadRequest, "Failed to verify id token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Failed to verify id token"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -186,10 +228,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("Failed to decode id token claims") Msg("Failed to decode id token claims")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusBadRequest, writer.WriteHeader(http.StatusBadRequest)
"Failed to decode id token claims", _, err := writer.Write([]byte("Failed to decode id token claims"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -199,10 +246,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if at := strings.LastIndex(claims.Email, "@"); at < 0 || if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) { !IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain") log.Error().Msg("authenticated principal does not match any allowed domain")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusBadRequest, writer.WriteHeader(http.StatusBadRequest)
"unauthorized principal (domain mismatch)", _, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -212,7 +264,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if len(h.cfg.OIDC.AllowedUsers) > 0 && if len(h.cfg.OIDC.AllowedUsers) > 0 &&
!IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) { !IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user") log.Error().Msg("authenticated principal does not match any allowed user")
ctx.String(http.StatusBadRequest, "unauthorized principal (user mismatch)") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -223,7 +283,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -237,17 +305,30 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse machine public key")
ctx.String(http.StatusBadRequest, "could not parse public key") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("could not parse public key"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
if !machineKeyOK { if !machineKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"could not get machine key from cache", _, err := writer.Write([]byte("could not get machine key from cache"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -264,7 +345,16 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("machine already registered, reauthenticating") Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, time.Time{}) err := h.RefreshMachine(machine, time.Time{})
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to refresh machine")
http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError)
return
}
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
@ -276,14 +366,29 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("type", "reauthenticate"). Str("type", "reauthenticate").
Err(err). Err(err).
Msg("Could not render OIDC callback template") Msg("Could not render OIDC callback template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render OIDC callback template"), _, err := writer.Write([]byte("Could not render OIDC callback template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -294,10 +399,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
) )
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email") log.Error().Err(err).Caller().Msgf("couldn't normalize email")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"couldn't normalize email", _, err := writer.Write([]byte("couldn't normalize email"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -314,10 +424,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msgf("could not create new namespace '%s'", namespaceName) Msgf("could not create new namespace '%s'", namespaceName)
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"could not create new namespace", _, err := writer.Write([]byte("could not create namespace"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -327,10 +442,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Str("namespace", namespaceName). Str("namespace", namespaceName).
Msg("could not find or create namespace") Msg("could not find or create namespace")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"could not find or create namespace", _, err := writer.Write([]byte("could not find or create namespace"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -347,10 +467,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Caller(). Caller().
Err(err). Err(err).
Msg("could not register machine") Msg("could not register machine")
ctx.String( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, writer.WriteHeader(http.StatusInternalServerError)
"could not register machine", _, err := writer.Write([]byte("could not register machine"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -365,12 +490,27 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("type", "authenticate"). Str("type", "authenticate").
Err(err). Err(err).
Msg("Could not render OIDC callback template") Msg("Could not render OIDC callback template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render OIDC callback template"), _, err := writer.Write([]byte("Could not render OIDC callback template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }

View file

@ -6,13 +6,16 @@ import (
"net/http" "net/http"
textTemplate "text/template" textTemplate "text/template"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client. // WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client.
func (h *Headscale) WindowsConfigMessage(ctx *gin.Context) { func (h *Headscale) WindowsConfigMessage(
writer http.ResponseWriter,
req *http.Request,
) {
winTemplate := template.Must(template.New("windows").Parse(` winTemplate := template.Must(template.New("windows").Parse(`
<html> <html>
<body> <body>
@ -63,20 +66,36 @@ REG ADD "HKLM\Software\Tailscale IPN" /v LoginURL /t REG_SZ /d "{{.URL}}"</code>
Str("handler", "WindowsRegConfig"). Str("handler", "WindowsRegConfig").
Err(err). Err(err).
Msg("Could not render Windows index template") Msg("Could not render Windows index template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Windows index template"), _, err := writer.Write([]byte("Could not render Windows index template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
// WindowsRegConfig generates and serves a .reg file configured with the Headscale server address. // WindowsRegConfig generates and serves a .reg file configured with the Headscale server address.
func (h *Headscale) WindowsRegConfig(ctx *gin.Context) { func (h *Headscale) WindowsRegConfig(
writer http.ResponseWriter,
req *http.Request,
) {
config := WindowsRegistryConfig{ config := WindowsRegistryConfig{
URL: h.cfg.ServerURL, URL: h.cfg.ServerURL,
} }
@ -87,24 +106,36 @@ func (h *Headscale) WindowsRegConfig(ctx *gin.Context) {
Str("handler", "WindowsRegConfig"). Str("handler", "WindowsRegConfig").
Err(err). Err(err).
Msg("Could not render Apple macOS template") Msg("Could not render Apple macOS template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Windows registry template"), _, err := writer.Write([]byte("Could not render Windows registry template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
ctx.Data( writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8")
http.StatusOK, writer.WriteHeader(http.StatusOK)
"text/x-ms-regedit; charset=utf-8", _, err := writer.Write(content.Bytes())
content.Bytes(), if err != nil {
) log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
func (h *Headscale) AppleConfigMessage(ctx *gin.Context) { func (h *Headscale) AppleConfigMessage(
writer http.ResponseWriter,
req *http.Request,
) {
appleTemplate := template.Must(template.New("apple").Parse(` appleTemplate := template.Must(template.New("apple").Parse(`
<html> <html>
<body> <body>
@ -165,20 +196,45 @@ func (h *Headscale) AppleConfigMessage(ctx *gin.Context) {
Str("handler", "AppleMobileConfig"). Str("handler", "AppleMobileConfig").
Err(err). Err(err).
Msg("Could not render Apple index template") Msg("Could not render Apple index template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Apple index template"), _, err := writer.Write([]byte("Could not render Apple index template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { func (h *Headscale) ApplePlatformConfig(
platform := ctx.Param("platform") writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
platform, ok := vars["platform"]
if !ok {
log.Error().
Str("handler", "ApplePlatformConfig").
Msg("No platform specified")
http.Error(writer, "No platform specified", http.StatusBadRequest)
return
}
id, err := uuid.NewV4() id, err := uuid.NewV4()
if err != nil { if err != nil {
@ -186,11 +242,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Failed to create UUID"), _, err := writer.Write([]byte("Failed to create UUID"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -201,11 +262,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Failed to create UUID"), _, err := writer.Write([]byte("Failed to create content UUID"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -224,11 +290,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple macOS template") Msg("Could not render Apple macOS template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Apple macOS template"), _, err := writer.Write([]byte("Could not render Apple macOS template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -238,20 +309,29 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple iOS template") Msg("Could not render Apple iOS template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Apple iOS template"), _, err := writer.Write([]byte("Could not render Apple iOS template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
default: default:
ctx.Data( writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusOK, writer.WriteHeader(http.StatusBadRequest)
"text/html; charset=utf-8", _, err := writer.Write([]byte("Invalid platform, only ios and macos is supported"))
[]byte("Invalid platform, only ios and macos is supported"), if err != nil {
) log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -268,20 +348,29 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple platform template") Msg("Could not render Apple platform template")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Apple platform template"), _, err := writer.Write([]byte("Could not render Apple platform template"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
ctx.Data( writer.Header().Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
http.StatusOK, writer.WriteHeader(http.StatusOK)
"application/x-apple-aspen-config; charset=utf-8", _, err = writer.Write(content.Bytes())
content.Bytes(), if err != nil {
) log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
type WindowsRegistryConfig struct { type WindowsRegistryConfig struct {

274
poll.go
View file

@ -8,7 +8,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -33,13 +33,25 @@ const machineNameContextKey = contextKey("machineName")
// only after their first request (marked with the ReadOnly field). // only after their first request (marked with the ReadOnly field).
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Msg("PollNetMapHandler called") Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body) body, _ := io.ReadAll(req.Body)
machineKeyStr := ctx.Param("id")
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
@ -48,18 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot parse client key") Msg("Cannot parse client key")
ctx.String(http.StatusBadRequest, "")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return return
} }
req := tailcfg.MapRequest{} mapRequest := tailcfg.MapRequest{}
err = decode(body, &req, &machineKey, h.privateKey) err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
ctx.String(http.StatusBadRequest, "") http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return return
} }
@ -70,26 +83,27 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
http.Error(writer, "", http.StatusUnauthorized)
return return
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "") http.Error(writer, "", http.StatusInternalServerError)
return return
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Found machine in database") Msg("Found machine in database")
machine.Hostname = req.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*req.Hostinfo) machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey) machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
now := time.Now().UTC() now := time.Now().UTC()
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
@ -111,8 +125,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// //
// The intended use is for clients to discover the DERP map at start-up // The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update. // before their first real endpoint update.
if !req.ReadOnly { if !mapRequest.ReadOnly {
machine.Endpoints = req.Endpoints machine.Endpoints = mapRequest.Endpoints
machine.LastSeen = &now machine.LastSeen = &now
} }
@ -120,25 +134,25 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to persist/update machine in the database") Msg("Failed to persist/update machine in the database")
ctx.String(http.StatusInternalServerError, ":(") http.Error(writer, "", http.StatusInternalServerError)
return return
} }
} }
data, err := h.getMapResponse(machineKey, req, machine) data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
ctx.String(http.StatusInternalServerError, ":(") http.Error(writer, "", http.StatusInternalServerError)
return return
} }
@ -150,19 +164,28 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Bool("readOnly", req.ReadOnly). Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", req.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", req.Stream). Bool("stream", mapRequest.Stream).
Msg("Client map request processed") Msg("Client map request processed")
if req.ReadOnly { if mapRequest.ReadOnly {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(data)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -177,7 +200,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Only create update channel if it has not been created // Only create update channel if it has not been created
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
@ -189,13 +212,20 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
keepAliveChan := make(chan []byte) keepAliveChan := make(chan []byte)
if req.OmitPeers && !req.Stream { if mapRequest.OmitPeers && !mapRequest.Stream {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(data)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update"). updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
@ -203,12 +233,12 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
updateChan <- struct{}{} updateChan <- struct{}{}
return return
} else if req.OmitPeers && req.Stream { } else if mapRequest.OmitPeers && mapRequest.Stream {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
ctx.String(http.StatusBadRequest, "") http.Error(writer, "", http.StatusBadRequest)
return return
} }
@ -232,9 +262,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
updateChan <- struct{}{} updateChan <- struct{}{}
h.PollNetMapStream( h.PollNetMapStream(
ctx, writer,
machine,
req, req,
machine,
mapRequest,
machineKey, machineKey,
pollDataChan, pollDataChan,
keepAliveChan, keepAliveChan,
@ -242,7 +273,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
) )
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
@ -251,7 +282,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// stream logic, ensuring we communicate updates and data // stream logic, ensuring we communicate updates and data
// to the connected clients. // to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
ctx *gin.Context, writer http.ResponseWriter,
req *http.Request,
machine *Machine, machine *Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -259,51 +291,31 @@ func (h *Headscale) PollNetMapStream(
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
) { ) {
{ ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return ctx, cancel := context.WithCancel(ctx)
} defer cancel()
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
return go h.scheduledPollWorker(
} ctx,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname) log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
ctx, cancel := context.WithCancel(ctx) log.Trace().
defer cancel() Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
go h.scheduledPollWorker( Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
ctx,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
}
ctx.Stream(func(writer io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
for {
select { select {
case data := <-pollDataChan: case data := <-pollDataChan:
log.Trace(). log.Trace().
@ -321,8 +333,21 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
return false return
} }
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -343,7 +368,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -360,16 +385,16 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot update machine LastSuccessfulUpdate") Msg("Cannot update machine LastSuccessfulUpdate")
} else {
log.Trace(). return
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData")
} }
return true log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending data")
case data := <-keepAliveChan: case data := <-keepAliveChan:
log.Trace(). log.Trace().
@ -387,8 +412,20 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
return false return
} }
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -409,7 +446,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -421,16 +458,16 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot update machine LastSeen") Msg("Cannot update machine LastSeen")
} else {
log.Trace(). return
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
} }
return true log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
case <-updateChan: case <-updateChan:
log.Trace(). log.Trace().
@ -440,6 +477,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Received a request for update") Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(machine.Namespace.Name, machine.Hostname). updateRequestsReceivedOnChannel.WithLabelValues(machine.Namespace.Name, machine.Hostname).
Inc() Inc()
if h.isOutdated(machine) { if h.isOutdated(machine) {
var lastUpdate time.Time var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil { if machine.LastSuccessfulUpdate != nil {
@ -459,6 +497,8 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
return
} }
_, err = writer.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
@ -471,8 +511,21 @@ func (h *Headscale) PollNetMapStream(
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed"). updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
Inc() Inc()
return false return
} }
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -499,7 +552,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false return
} }
now := time.Now().UTC() now := time.Now().UTC()
@ -515,6 +568,8 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Cannot update machine LastSuccessfulUpdate") Msg("Cannot update machine LastSuccessfulUpdate")
return
} }
} else { } else {
var lastUpdate time.Time var lastUpdate time.Time
@ -529,9 +584,7 @@ func (h *Headscale) PollNetMapStream(
Msgf("%s is up to date", machine.Hostname) Msgf("%s is up to date", machine.Hostname)
} }
return true case <-ctx.Done():
case <-ctx.Request.Context().Done():
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -550,7 +603,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -564,9 +617,18 @@ func (h *Headscale) PollNetMapStream(
Msg("Cannot update machine LastSeen") Msg("Cannot update machine LastSeen")
} }
return false // The connection has been closed, so we can stop polling.
return
case <-h.shutdownChan:
log.Info().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("The long-poll handler is shutting down")
return
} }
}) }
} }
func (h *Headscale) scheduledPollWorker( func (h *Headscale) scheduledPollWorker(

View file

@ -28,7 +28,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_get_route_machine", Hostname: "test_get_route_machine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
@ -79,7 +79,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),

View file

@ -6,14 +6,16 @@ import (
"html/template" "html/template"
"net/http" "net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte var apiV1JSON []byte
func SwaggerUI(ctx *gin.Context) { func SwaggerUI(
writer http.ResponseWriter,
req *http.Request,
) {
swaggerTemplate := template.Must(template.New("swagger").Parse(` swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html> <html>
<head> <head>
@ -52,18 +54,41 @@ func SwaggerUI(ctx *gin.Context) {
Caller(). Caller().
Err(err). Err(err).
Msg("Could not render Swagger") Msg("Could not render Swagger")
ctx.Data(
http.StatusInternalServerError, writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", writer.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render Swagger"), _, err := writer.Write([]byte("Could not render Swagger"))
) if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }
func SwaggerAPIv1(ctx *gin.Context) { func SwaggerAPIv1(
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) writer http.ResponseWriter,
req *http.Request,
) {
writer.Header().Set("Content-Type", "application/json; charset=utf-88")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(apiV1JSON); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} }

View file

@ -324,18 +324,18 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
// It will return an error if the system's secure random // It will return an error if the system's secure random
// number generator fails to function correctly, in which // number generator fails to function correctly, in which
// case the caller should not continue. // case the caller should not continue.
func GenerateRandomStringDNSSafe(n int) (string, error) { func GenerateRandomStringDNSSafe(size int) (string, error) {
var str string var str string
var err error var err error
for len(str) < n { for len(str) < size {
str, err = GenerateRandomStringURLSafe(n) str, err = GenerateRandomStringURLSafe(size)
if err != nil { if err != nil {
return "", err return "", err
} }
str = strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", "")) str = strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""))
} }
return str[:n], nil return str[:size], nil
} }
func IsStringInSlice(slice []string, str string) bool { func IsStringInSlice(slice []string, str string) bool {