diff --git a/CHANGELOG.md b/CHANGELOG.md index ad3bbfbf..9e916056 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,8 +30,11 @@ - Add -c option to specify config file from command line [#285](https://github.com/juanfont/headscale/issues/285) [#612](https://github.com/juanfont/headscale/pull/601) - Add configuration option to allow Tailscale clients to use a random WireGuard port. [kb/1181/firewalls](https://tailscale.com/kb/1181/firewalls) [#624](https://github.com/juanfont/headscale/pull/624) - Improve obtuse UX regarding missing configuration (`ephemeral_node_inactivity_timeout` not set) [#639](https://github.com/juanfont/headscale/pull/639) +- Fix nodes being shown as 'offline' in `tailscale status` [648](https://github.com/juanfont/headscale/pull/648) - Fix nodes being shown as 'offline' in `tailscale status` [#648](https://github.com/juanfont/headscale/pull/648) - Improve shutdown behaviour [#651](https://github.com/juanfont/headscale/pull/651) +- Drop Gin as web framework in Headscale [648](https://github.com/juanfont/headscale/pull/648) + ## 0.15.0 (2022-03-20) diff --git a/acls.go b/acls.go index c7a84afc..b485ce30 100644 --- a/acls.go +++ b/acls.go @@ -37,7 +37,7 @@ const ( expectedTokenItems = 2 ) -// For some reason golang.org/x/net/internal/iana is an internal package +// For some reason golang.org/x/net/internal/iana is an internal package. const ( protocolICMP = 1 // Internet Control Message protocolIGMP = 2 // Internet Group Management diff --git a/api.go b/api.go index 45fd7793..fc27e46b 100644 --- a/api.go +++ b/api.go @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" + "github.com/gorilla/mux" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "gorm.io/gorm" @@ -32,12 +32,19 @@ const ( // KeyHandler provides the Headscale pub key // Listens in /key. -func (h *Headscale) KeyHandler(ctx *gin.Context) { - ctx.Data( - http.StatusOK, - "text/plain; charset=utf-8", - []byte(MachinePublicKeyStripPrefix(h.privateKey.Public())), - ) +func (h *Headscale) KeyHandler( + writer http.ResponseWriter, + req *http.Request, +) { + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public()))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } type registerWebAPITemplateConfig struct { @@ -63,10 +70,21 @@ var registerWebAPITemplate = template.Must( // RegisterWebAPI shows a simple message in the browser to point to the CLI // Listens in /register. -func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { - machineKeyStr := ctx.Query("key") +func (h *Headscale) RegisterWebAPI( + writer http.ResponseWriter, + req *http.Request, +) { + machineKeyStr := req.URL.Query().Get("key") if machineKeyStr == "" { - ctx.String(http.StatusBadRequest, "Wrong params") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Wrong params")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -79,21 +97,48 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { Str("func", "RegisterWebAPI"). Err(err). Msg("Could not render register web API template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render register web API template"), - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err = writer.Write([]byte("Could not render register web API template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(content.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } // RegistrationHandler handles the actual registration process of a machine -// Endpoint /machine/:id. -func (h *Headscale) RegistrationHandler(ctx *gin.Context) { - body, _ := io.ReadAll(ctx.Request.Body) - machineKeyStr := ctx.Param("id") +// Endpoint /machine/:mkey. +func (h *Headscale) RegistrationHandler( + writer http.ResponseWriter, + req *http.Request, +) { + vars := mux.Vars(req) + machineKeyStr, ok := vars["mkey"] + if !ok || machineKeyStr == "" { + log.Error(). + Str("handler", "RegistrationHandler"). + Msg("No machine ID in request") + http.Error(writer, "No machine ID in request", http.StatusBadRequest) + + return + } + + body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) @@ -103,19 +148,19 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { Err(err). Msg("Cannot parse machine key") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - ctx.String(http.StatusInternalServerError, "Sad!") + http.Error(writer, "Cannot parse machine key", http.StatusBadRequest) return } - req := tailcfg.RegisterRequest{} - err = decode(body, &req, &machineKey, h.privateKey) + registerRequest := tailcfg.RegisterRequest{} + err = decode(body, ®isterRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot decode message") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - ctx.String(http.StatusInternalServerError, "Very sad!") + http.Error(writer, "Cannot decode message", http.StatusBadRequest) return } @@ -123,23 +168,23 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { now := time.Now().UTC() machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { - log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") + log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine") machineKeyStr := MachinePublicKeyStripPrefix(machineKey) // If the machine has AuthKey set, handle registration via PreAuthKeys - if req.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, req) + if registerRequest.Auth.AuthKey != "" { + h.handleAuthKey(writer, req, machineKey, registerRequest) return } - givenName, err := h.GenerateGivenName(req.Hostinfo.Hostname) + givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). Caller(). Str("func", "RegistrationHandler"). - Str("hostinfo.name", req.Hostinfo.Hostname). + Str("hostinfo.name", registerRequest.Hostinfo.Hostname). Err(err) return @@ -151,20 +196,20 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // happens newMachine := Machine{ MachineKey: machineKeyStr, - Hostname: req.Hostinfo.Hostname, + Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, - NodeKey: NodePublicKeyStripPrefix(req.NodeKey), + NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey), LastSeen: &now, Expiry: &time.Time{}, } - if !req.Expiry.IsZero() { + if !registerRequest.Expiry.IsZero() { log.Trace(). Caller(). - Str("machine", req.Hostinfo.Hostname). - Time("expiry", req.Expiry). + Str("machine", registerRequest.Hostinfo.Hostname). + Time("expiry", registerRequest.Expiry). Msg("Non-zero expiry time requested") - newMachine.Expiry = &req.Expiry + newMachine.Expiry = ®isterRequest.Expiry } h.registrationCache.Set( @@ -173,7 +218,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { registerCacheExpiration, ) - h.handleMachineRegistrationNew(ctx, machineKey, req) + h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) return } @@ -185,11 +230,11 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // - Trying to log out (sending a expiry in the past) // - A valid, registered machine, looking for the node map // - Expired machine wanting to reauthenticate - if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) { + if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 - if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { - h.handleMachineLogOut(ctx, machineKey, *machine) + if !registerRequest.Expiry.IsZero() && registerRequest.Expiry.UTC().Before(now) { + h.handleMachineLogOut(writer, req, machineKey, *machine) return } @@ -197,22 +242,22 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // If machine is not expired, and is register, we have a already accepted this machine, // let it proceed with a valid registration if !machine.isExpired() { - h.handleMachineValidRegistration(ctx, machineKey, *machine) + h.handleMachineValidRegistration(writer, req, machineKey, *machine) return } } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) && + if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && !machine.isExpired() { - h.handleMachineRefreshKey(ctx, machineKey, req, *machine) + h.handleMachineRefreshKey(writer, req, machineKey, registerRequest, *machine) return } // The machine has expired - h.handleMachineExpired(ctx, machineKey, req, *machine) + h.handleMachineExpired(writer, req, machineKey, registerRequest, *machine) return } @@ -220,12 +265,12 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { func (h *Headscale) getMapResponse( machineKey key.MachinePublic, - req tailcfg.MapRequest, + mapRequest tailcfg.MapRequest, machine *Machine, ) ([]byte, error) { log.Trace(). Str("func", "getMapResponse"). - Str("machine", req.Hostinfo.Hostname). + Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { @@ -286,12 +331,12 @@ func (h *Headscale) getMapResponse( log.Trace(). Str("func", "getMapResponse"). - Str("machine", req.Hostinfo.Hostname). + Str("machine", mapRequest.Hostinfo.Hostname). // Interface("payload", resp). Msgf("Generated map response: %s", tailMapResponseToString(resp)) var respBody []byte - if req.Compress == "zstd" { + if mapRequest.Compress == "zstd" { src, err := json.Marshal(resp) if err != nil { log.Error(). @@ -357,7 +402,8 @@ func (h *Headscale) getMapKeepAliveResponse( } func (h *Headscale) handleMachineLogOut( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, machine Machine, ) { @@ -367,7 +413,17 @@ func (h *Headscale) handleMachineLogOut( Str("machine", machine.Hostname). Msg("Client requested logout") - h.ExpireMachine(&machine) + err := h.ExpireMachine(&machine) + if err != nil { + log.Error(). + Caller(). + Str("func", "handleMachineLogOut"). + Err(err). + Msg("Failed to expire machine") + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } resp.AuthURL = "" resp.MachineAuthorized = false @@ -378,15 +434,25 @@ func (h *Headscale) handleMachineLogOut( Caller(). Err(err). Msg("Cannot encode message") - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } func (h *Headscale) handleMachineValidRegistration( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, machine Machine, ) { @@ -410,17 +476,27 @@ func (h *Headscale) handleMachineValidRegistration( Msg("Cannot encode message") machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). Inc() - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). Inc() - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } func (h *Headscale) handleMachineExpired( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, @@ -433,7 +509,7 @@ func (h *Headscale) handleMachineExpired( Msg("Machine registration has expired. Sending a authurl to register") if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, registerRequest) + h.handleAuthKey(writer, req, machineKey, registerRequest) return } @@ -454,17 +530,27 @@ func (h *Headscale) handleMachineExpired( Msg("Cannot encode message") machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). Inc() - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). Inc() - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } func (h *Headscale) handleMachineRefreshKey( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, @@ -481,7 +567,7 @@ func (h *Headscale) handleMachineRefreshKey( Caller(). Err(err). Msg("Failed to update machine key in the database") - ctx.String(http.StatusInternalServerError, "Internal server error") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } @@ -494,15 +580,25 @@ func (h *Headscale) handleMachineRefreshKey( Caller(). Err(err). Msg("Cannot encode message") - ctx.String(http.StatusInternalServerError, "Internal server error") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } func (h *Headscale) handleMachineRegistrationNew( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, ) { @@ -529,16 +625,26 @@ func (h *Headscale) handleMachineRegistrationNew( Caller(). Err(err). Msg("Cannot encode message") - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } // TODO: check if any locks are needed around IP allocation. func (h *Headscale) handleAuthKey( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, ) { @@ -567,14 +673,23 @@ func (h *Headscale) handleAuthKey( Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "Internal server error", http.StatusInternalServerError) machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() return } - ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusUnauthorized) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + log.Error(). Caller(). Str("func", "handleAuthKey"). @@ -611,7 +726,16 @@ func (h *Headscale) handleAuthKey( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - h.RefreshMachine(machine, registerRequest.Expiry) + err := h.RefreshMachine(machine, registerRequest.Expiry) + if err != nil { + log.Error(). + Caller(). + Str("machine", machine.Hostname). + Err(err). + Msg("Failed to refresh machine") + + return + } } else { now := time.Now().UTC() @@ -648,16 +772,24 @@ func (h *Headscale) handleAuthKey( Msg("could not register machine") machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() - ctx.String( - http.StatusInternalServerError, - "could not register machine", - ) + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } } - h.UsePreAuthKey(pak) + err = h.UsePreAuthKey(pak) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to use pre-auth key") + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). + Inc() + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() @@ -671,13 +803,22 @@ func (h *Headscale) handleAuthKey( Msg("Cannot encode message") machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() - ctx.String(http.StatusInternalServerError, "Extremely sad!") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name). Inc() - ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(respBody) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + log.Info(). Str("func", "handleAuthKey"). Str("machine", registerRequest.Hostinfo.Hostname). diff --git a/app.go b/app.go index 78c1ad6b..e4e69105 100644 --- a/app.go +++ b/app.go @@ -18,6 +18,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" + "github.com/gorilla/mux" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -93,6 +94,8 @@ type Headscale struct { registrationCache *cache.Cache ipAllocationMutex sync.Mutex + + shutdownChan chan struct{} } // Look up the TLS constant relative to user-supplied TLS client @@ -327,48 +330,74 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, return handler(ctx, req) } -func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { - log.Trace(). - Caller(). - Str("client_address", ctx.ClientIP()). - Msg("HTTP authentication invoked") - - authHeader := ctx.GetHeader("authorization") - - if !strings.HasPrefix(authHeader, AuthPrefix) { - log.Error(). +func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func( + writer http.ResponseWriter, + req *http.Request, + ) { + log.Trace(). Caller(). - Str("client_address", ctx.ClientIP()). - Msg(`missing "Bearer " prefix in "Authorization" header`) - ctx.AbortWithStatus(http.StatusUnauthorized) + Str("client_address", req.RemoteAddr). + Msg("HTTP authentication invoked") - return - } + authHeader := req.Header.Get("authorization") - valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) - if err != nil { - log.Error(). - Caller(). - Err(err). - Str("client_address", ctx.ClientIP()). - Msg("failed to validate token") + if !strings.HasPrefix(authHeader, AuthPrefix) { + log.Error(). + Caller(). + Str("client_address", req.RemoteAddr). + Msg(`missing "Bearer " prefix in "Authorization" header`) + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } - ctx.AbortWithStatus(http.StatusInternalServerError) + return + } - return - } + valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) + if err != nil { + log.Error(). + Caller(). + Err(err). + Str("client_address", req.RemoteAddr). + Msg("failed to validate token") - if !valid { - log.Info(). - Str("client_address", ctx.ClientIP()). - Msg("invalid token") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Unauthorized")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } - ctx.AbortWithStatus(http.StatusUnauthorized) + return + } - return - } + if !valid { + log.Info(). + Str("client_address", req.RemoteAddr). + Msg("invalid token") - ctx.Next() + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + + next.ServeHTTP(writer, req) + }) } // ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear @@ -391,39 +420,48 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine { return promRouter } -func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { - router := gin.Default() +func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { + router := mux.NewRouter() - router.GET( + router.HandleFunc( "/health", - func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, - ) - router.GET("/key", h.KeyHandler) - router.GET("/register", h.RegisterWebAPI) - router.POST("/machine/:id/map", h.PollNetMapHandler) - router.POST("/machine/:id", h.RegistrationHandler) - router.GET("/oidc/register/:mkey", h.RegisterOIDC) - router.GET("/oidc/callback", h.OIDCCallback) - router.GET("/apple", h.AppleConfigMessage) - router.GET("/apple/:platform", h.ApplePlatformConfig) - router.GET("/windows", h.WindowsConfigMessage) - router.GET("/windows/tailscale.reg", h.WindowsRegConfig) - router.GET("/swagger", SwaggerUI) - router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) + func(writer http.ResponseWriter, req *http.Request) { + writer.WriteHeader(http.StatusOK) + _, err := writer.Write([]byte("{\"healthy\": \"ok\"}")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + }).Methods(http.MethodGet) + + router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) + router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) + router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) + router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) + router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) + router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) + router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) + router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet) + router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) + router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) + router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) + router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) if h.cfg.DERP.ServerEnabled { - router.Any("/derp", h.DERPHandler) - router.Any("/derp/probe", h.DERPProbeHandler) - router.Any("/bootstrap-dns", h.DERPBootstrapDNSHandler) + router.HandleFunc("/derp", h.DERPHandler) + router.HandleFunc("/derp/probe", h.DERPProbeHandler) + router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler) } - api := router.Group("/api") + api := router.PathPrefix("/api").Subrouter() api.Use(h.httpAuthenticationMiddleware) { - api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP)) + api.HandleFunc("/v1/*any", grpcMux.ServeHTTP) } - router.NoRoute(stdoutHandler) + router.PathPrefix("/").HandlerFunc(stdoutHandler) return router } @@ -631,6 +669,7 @@ func (h *Headscale) Serve() error { Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) // Handle common process-killing signals so we can gracefully shut down: + h.shutdownChan = make(chan struct{}) sigc := make(chan os.Signal, 1) signal.Notify(sigc, syscall.SIGHUP, @@ -668,6 +707,8 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") + h.shutdownChan <- struct{}{} + // Gracefully shut down servers ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) if err := promHTTPServer.Shutdown(ctx); err != nil { @@ -831,13 +872,16 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { } } -func stdoutHandler(ctx *gin.Context) { - body, _ := io.ReadAll(ctx.Request.Body) +func stdoutHandler( + writer http.ResponseWriter, + req *http.Request, +) { + body, _ := io.ReadAll(req.Body) log.Trace(). - Interface("header", ctx.Request.Header). - Interface("proto", ctx.Request.Proto). - Interface("url", ctx.Request.URL). + Interface("header", req.Header). + Interface("proto", req.Proto). + Interface("url", req.URL). Bytes("body", body). Msg("Request did not match") } diff --git a/db.go b/db.go index 17d83237..e412468d 100644 --- a/db.go +++ b/db.go @@ -89,7 +89,7 @@ func (h *Headscale) initDB() error { log.Error().Err(err).Msg("Error accessing db") } - for _, machine := range machines { + for item, machine := range machines { if machine.GivenName == "" { normalizedHostname, err := NormalizeToFQDNRules( machine.Hostname, @@ -103,7 +103,7 @@ func (h *Headscale) initDB() error { Msg("Failed to normalize machine hostname in DB migration") } - err = h.RenameMachine(&machine, normalizedHostname) + err = h.RenameMachine(&machines[item], normalizedHostname) if err != nil { log.Error(). Caller(). @@ -111,7 +111,6 @@ func (h *Headscale) initDB() error { Err(err). Msg("Failed to save normalized machine name in DB migration") } - } } } diff --git a/derp_server.go b/derp_server.go index d6fb47de..098ca53e 100644 --- a/derp_server.go +++ b/derp_server.go @@ -2,6 +2,7 @@ package headscale import ( "context" + "encoding/json" "fmt" "net" "net/http" @@ -10,7 +11,6 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" "tailscale.com/derp" "tailscale.com/net/stun" @@ -30,6 +30,7 @@ type DERPServer struct { } func (h *Headscale) NewDERPServer() (*DERPServer, error) { + log.Trace().Caller().Msg("Creating new embedded DERP server") server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf) region, err := h.generateRegionLocalDERP() if err != nil { @@ -87,30 +88,48 @@ func (h *Headscale) generateRegionLocalDERP() (tailcfg.DERPRegion, error) { } localDERPregion.Nodes[0].STUNPort = portSTUN + log.Info().Caller().Msgf("DERP region: %+v", localDERPregion) + return localDERPregion, nil } -func (h *Headscale) DERPHandler(ctx *gin.Context) { - log.Trace().Caller().Msgf("/derp request from %v", ctx.ClientIP()) - up := strings.ToLower(ctx.Request.Header.Get("Upgrade")) +func (h *Headscale) DERPHandler( + writer http.ResponseWriter, + req *http.Request, +) { + log.Trace().Caller().Msgf("/derp request from %v", req.RemoteAddr) + up := strings.ToLower(req.Header.Get("Upgrade")) if up != "websocket" && up != "derp" { if up != "" { log.Warn().Caller().Msgf("Weird websockets connection upgrade: %q", up) } - ctx.String(http.StatusUpgradeRequired, "DERP requires connection upgrade") + writer.Header().Set("Content-Type", "text/plain") + writer.WriteHeader(http.StatusUpgradeRequired) + _, err := writer.Write([]byte("DERP requires connection upgrade")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - fastStart := ctx.Request.Header.Get(fastStartHeader) == "1" + fastStart := req.Header.Get(fastStartHeader) == "1" - hijacker, ok := ctx.Writer.(http.Hijacker) + hijacker, ok := writer.(http.Hijacker) if !ok { log.Error().Caller().Msg("DERP requires Hijacker interface from Gin") - ctx.String( - http.StatusInternalServerError, - "HTTP does not support general TCP support", - ) + writer.Header().Set("Content-Type", "text/plain") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("HTTP does not support general TCP support")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -118,13 +137,19 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) { netConn, conn, err := hijacker.Hijack() if err != nil { log.Error().Caller().Err(err).Msgf("Hijack failed") - ctx.String( - http.StatusInternalServerError, - "HTTP does not support general TCP support", - ) + writer.Header().Set("Content-Type", "text/plain") + writer.WriteHeader(http.StatusInternalServerError) + _, err = writer.Write([]byte("HTTP does not support general TCP support")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } + log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr) if !fastStart { pubKey := h.privateKey.Public() @@ -143,12 +168,23 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) { // DERPProbeHandler is the endpoint that js/wasm clients hit to measure // DERP latency, since they can't do UDP STUN queries. -func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { - switch ctx.Request.Method { +func (h *Headscale) DERPProbeHandler( + writer http.ResponseWriter, + req *http.Request, +) { + switch req.Method { case "HEAD", "GET": - ctx.Writer.Header().Set("Access-Control-Allow-Origin", "*") + writer.Header().Set("Access-Control-Allow-Origin", "*") + writer.WriteHeader(http.StatusOK) default: - ctx.String(http.StatusMethodNotAllowed, "bogus probe method") + writer.WriteHeader(http.StatusMethodNotAllowed) + _, err := writer.Write([]byte("bogus probe method")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } } @@ -159,15 +195,18 @@ func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // An example implementation is found here https://derp.tailscale.com/bootstrap-dns -func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) { +func (h *Headscale) DERPBootstrapDNSHandler( + writer http.ResponseWriter, + req *http.Request, +) { dnsEntries := make(map[string][]net.IP) resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - var r net.Resolver + var resolver net.Resolver for _, region := range h.DERPMap.Regions { for _, node := range region.Nodes { // we don't care if we override some nodes - addrs, err := r.LookupIP(resolvCtx, "ip", node.HostName) + addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName) if err != nil { log.Trace(). Caller(). @@ -179,7 +218,15 @@ func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) { dnsEntries[node.HostName] = addrs } } - ctx.JSON(http.StatusOK, dnsEntries) + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + err := json.NewEncoder(writer).Encode(dnsEntries) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } // ServeSTUN starts a STUN server on the configured addr. diff --git a/flake.nix b/flake.nix index e7977e98..afa8c8bb 100644 --- a/flake.nix +++ b/flake.nix @@ -24,7 +24,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorSha256 = "sha256-j/hI6vP92UmcexFfzCe5fkGE8QUdQdNajSxMGib175Q="; + vendorSha256 = "sha256-T6rH+aqofFmCPxDfoA5xd3kNUJeZkT4GRyuFEnenps8="; ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ]; }; diff --git a/go.mod b/go.mod index 70662579..e10ae35e 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gin-gonic/gin v1.7.7 github.com/glebarez/sqlite v1.4.3 github.com/gofrs/uuid v4.2.0+incompatible + github.com/gorilla/mux v1.8.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.0 github.com/klauspost/compress v1.15.4 diff --git a/go.sum b/go.sum index 9423737e..b4d03b90 100644 --- a/go.sum +++ b/go.sum @@ -403,6 +403,7 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw= github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/machine.go b/machine.go index cf6b8876..1bed2955 100644 --- a/machine.go +++ b/machine.go @@ -27,6 +27,7 @@ const ( errCouldNotConvertMachineInterface = Error("failed to convert machine interface") errHostnameTooLong = Error("Hostname too long") MachineGivenNameHashLength = 8 + MachineGivenNameTrimSize = 2 ) const ( @@ -898,7 +899,7 @@ func (machine *Machine) RoutesToProto() *v1.Routes { func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) { // If a hostname is or will be longer than 63 chars after adding the hash, // it needs to be trimmed. - trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2 + trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize normalizedHostname, err := NormalizeToFQDNRules( suppliedName, diff --git a/machine_test.go b/machine_test.go index 48ccb153..a06d0db2 100644 --- a/machine_test.go +++ b/machine_test.go @@ -249,10 +249,12 @@ func (s *Suite) TestExpireMachine(c *check.C) { machineFromDB, err := app.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) + c.Assert(machineFromDB, check.NotNil) c.Assert(machineFromDB.isExpired(), check.Equals, false) - app.ExpireMachine(machineFromDB) + err = app.ExpireMachine(machineFromDB) + c.Assert(err, check.IsNil) c.Assert(machineFromDB.isExpired(), check.Equals, true) } @@ -918,6 +920,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { err, tt.wantErr, ) + return } diff --git a/oidc.go b/oidc.go index 38a1eb36..8b5f0242 100644 --- a/oidc.go +++ b/oidc.go @@ -13,7 +13,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gin-gonic/gin" + "github.com/gorilla/mux" "github.com/rs/zerolog/log" "golang.org/x/oauth2" "tailscale.com/types/key" @@ -63,10 +63,17 @@ func (h *Headscale) initOIDC() error { // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param // Listens in /oidc/register/:mKey. -func (h *Headscale) RegisterOIDC(ctx *gin.Context) { - machineKeyStr := ctx.Param("mkey") - if machineKeyStr == "" { - ctx.String(http.StatusBadRequest, "Wrong params") +func (h *Headscale) RegisterOIDC( + writer http.ResponseWriter, + req *http.Request, +) { + vars := mux.Vars(req) + machineKeyStr, ok := vars["mkey"] + if !ok || machineKeyStr == "" { + log.Error(). + Caller(). + Msg("Missing machine key in URL") + http.Error(writer, "Missing machine key in URL", http.StatusBadRequest) return } @@ -81,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { log.Error(). Caller(). Msg("could not read 16 bytes from rand") - ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return } @@ -101,7 +108,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...) log.Debug().Msgf("Redirecting to %s for authentication", authURL) - ctx.Redirect(http.StatusFound, authURL) + http.Redirect(writer, req, authURL, http.StatusFound) } type oidcCallbackTemplateConfig struct { @@ -125,12 +132,23 @@ var oidcCallbackTemplate = template.Must( // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback. -func (h *Headscale) OIDCCallback(ctx *gin.Context) { - code := ctx.Query("code") - state := ctx.Query("state") +func (h *Headscale) OIDCCallback( + writer http.ResponseWriter, + req *http.Request, +) { + code := req.URL.Query().Get("code") + state := req.URL.Query().Get("state") if code == "" || state == "" { - ctx.String(http.StatusBadRequest, "Wrong params") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Wrong params")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -141,7 +159,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("Could not exchange code for token") - ctx.String(http.StatusBadRequest, "Could not exchange code for token") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Could not exchange code for token")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -154,7 +180,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { - ctx.String(http.StatusBadRequest, "Could not extract ID Token") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Could not extract ID Token")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -167,7 +201,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("failed to verify id token") - ctx.String(http.StatusBadRequest, "Failed to verify id token") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Failed to verify id token")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -186,10 +228,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("Failed to decode id token claims") - ctx.String( - http.StatusBadRequest, - "Failed to decode id token claims", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Failed to decode id token claims")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -199,10 +246,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) { log.Error().Msg("authenticated principal does not match any allowed domain") - ctx.String( - http.StatusBadRequest, - "unauthorized principal (domain mismatch)", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("unauthorized principal (domain mismatch)")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -212,7 +264,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if len(h.cfg.OIDC.AllowedUsers) > 0 && !IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) { log.Error().Msg("authenticated principal does not match any allowed user") - ctx.String(http.StatusBadRequest, "unauthorized principal (user mismatch)") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("unauthorized principal (user mismatch)")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -223,7 +283,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if !machineKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") - ctx.String(http.StatusBadRequest, "state has expired") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("state has expired")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -237,17 +305,30 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if err != nil { log.Error(). Msg("could not parse machine public key") - ctx.String(http.StatusBadRequest, "could not parse public key") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("could not parse public key")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } if !machineKeyOK { log.Error().Msg("could not get machine key from cache") - ctx.String( - http.StatusInternalServerError, - "could not get machine key from cache", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("could not get machine key from cache")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -264,7 +345,16 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, time.Time{}) + err := h.RefreshMachine(machine, time.Time{}) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to refresh machine") + http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError) + + return + } var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ @@ -276,14 +366,29 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("type", "reauthenticate"). Err(err). Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render OIDC callback template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(content.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -294,10 +399,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { ) if err != nil { log.Error().Err(err).Caller().Msgf("couldn't normalize email") - ctx.String( - http.StatusInternalServerError, - "couldn't normalize email", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("couldn't normalize email")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -314,10 +424,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("could not create namespace")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -327,10 +442,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Str("namespace", namespaceName). Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("could not find or create namespace")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -347,10 +467,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Caller(). Err(err). Msg("could not register machine") - ctx.String( - http.StatusInternalServerError, - "could not register machine", - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("could not register machine")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -365,12 +490,27 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("type", "authenticate"). Err(err). Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render OIDC callback template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(content.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } diff --git a/platform_config.go b/platform_config.go index d36a37ce..6bb195c7 100644 --- a/platform_config.go +++ b/platform_config.go @@ -6,13 +6,16 @@ import ( "net/http" textTemplate "text/template" - "github.com/gin-gonic/gin" "github.com/gofrs/uuid" + "github.com/gorilla/mux" "github.com/rs/zerolog/log" ) // WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client. -func (h *Headscale) WindowsConfigMessage(ctx *gin.Context) { +func (h *Headscale) WindowsConfigMessage( + writer http.ResponseWriter, + req *http.Request, +) { winTemplate := template.Must(template.New("windows").Parse(` @@ -63,20 +66,36 @@ REG ADD "HKLM\Software\Tailscale IPN" /v LoginURL /t REG_SZ /d "{{.URL}}" Str("handler", "WindowsRegConfig"). Err(err). Msg("Could not render Windows index template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Windows index template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Windows index template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(payload.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } // WindowsRegConfig generates and serves a .reg file configured with the Headscale server address. -func (h *Headscale) WindowsRegConfig(ctx *gin.Context) { +func (h *Headscale) WindowsRegConfig( + writer http.ResponseWriter, + req *http.Request, +) { config := WindowsRegistryConfig{ URL: h.cfg.ServerURL, } @@ -87,24 +106,36 @@ func (h *Headscale) WindowsRegConfig(ctx *gin.Context) { Str("handler", "WindowsRegConfig"). Err(err). Msg("Could not render Apple macOS template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Windows registry template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Windows registry template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - ctx.Data( - http.StatusOK, - "text/x-ms-regedit; charset=utf-8", - content.Bytes(), - ) + writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(content.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. -func (h *Headscale) AppleConfigMessage(ctx *gin.Context) { +func (h *Headscale) AppleConfigMessage( + writer http.ResponseWriter, + req *http.Request, +) { appleTemplate := template.Must(template.New("apple").Parse(` @@ -165,20 +196,45 @@ func (h *Headscale) AppleConfigMessage(ctx *gin.Context) { Str("handler", "AppleMobileConfig"). Err(err). Msg("Could not render Apple index template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Apple index template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Apple index template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(payload.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } -func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { - platform := ctx.Param("platform") +func (h *Headscale) ApplePlatformConfig( + writer http.ResponseWriter, + req *http.Request, +) { + vars := mux.Vars(req) + platform, ok := vars["platform"] + if !ok { + log.Error(). + Str("handler", "ApplePlatformConfig"). + Msg("No platform specified") + http.Error(writer, "No platform specified", http.StatusBadRequest) + + return + } id, err := uuid.NewV4() if err != nil { @@ -186,11 +242,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Failed to create UUID"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Failed to create UUID")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -201,11 +262,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Failed to create UUID"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Failed to create content UUID")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -224,11 +290,16 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple macOS template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Apple macOS template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Apple macOS template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -238,20 +309,29 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple iOS template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Apple iOS template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Apple iOS template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } default: - ctx.Data( - http.StatusOK, - "text/html; charset=utf-8", - []byte("Invalid platform, only ios and macos is supported"), - ) + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Invalid platform, only ios and macos is supported")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -268,20 +348,29 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple platform template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Apple platform template"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Apple platform template")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - ctx.Data( - http.StatusOK, - "application/x-apple-aspen-config; charset=utf-8", - content.Bytes(), - ) + writer.Header().Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(content.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } type WindowsRegistryConfig struct { diff --git a/poll.go b/poll.go index 239f260b..9218495d 100644 --- a/poll.go +++ b/poll.go @@ -8,7 +8,7 @@ import ( "net/http" "time" - "github.com/gin-gonic/gin" + "github.com/gorilla/mux" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -33,13 +33,25 @@ const machineNameContextKey = contextKey("machineName") // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { +func (h *Headscale) PollNetMapHandler( + writer http.ResponseWriter, + req *http.Request, +) { + vars := mux.Vars(req) + machineKeyStr, ok := vars["mkey"] + if !ok || machineKeyStr == "" { + log.Error(). + Str("handler", "PollNetMap"). + Msg("No machine key in request") + http.Error(writer, "No machine key in request", http.StatusBadRequest) + + return + } log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Msg("PollNetMapHandler called") - body, _ := io.ReadAll(ctx.Request.Body) - machineKeyStr := ctx.Param("id") + body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) @@ -48,18 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Err(err). Msg("Cannot parse client key") - ctx.String(http.StatusBadRequest, "") + + http.Error(writer, "Cannot parse client key", http.StatusBadRequest) return } - req := tailcfg.MapRequest{} - err = decode(body, &req, &machineKey, h.privateKey) + mapRequest := tailcfg.MapRequest{} + err = decode(body, &mapRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "PollNetMap"). Err(err). Msg("Cannot decode message") - ctx.String(http.StatusBadRequest, "") + http.Error(writer, "Cannot decode message", http.StatusBadRequest) return } @@ -70,26 +83,27 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { log.Warn(). Str("handler", "PollNetMap"). Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) - ctx.String(http.StatusUnauthorized, "") + + http.Error(writer, "", http.StatusUnauthorized) return } log.Error(). Str("handler", "PollNetMap"). Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) - ctx.String(http.StatusInternalServerError, "") + http.Error(writer, "", http.StatusInternalServerError) return } log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Found machine in database") - machine.Hostname = req.Hostinfo.Hostname - machine.HostInfo = HostInfo(*req.Hostinfo) - machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey) + machine.Hostname = mapRequest.Hostinfo.Hostname + machine.HostInfo = HostInfo(*mapRequest.Hostinfo) + machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) now := time.Now().UTC() // update ACLRules with peer informations (to update server tags if necessary) @@ -111,8 +125,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // // The intended use is for clients to discover the DERP map at start-up // before their first real endpoint update. - if !req.ReadOnly { - machine.Endpoints = req.Endpoints + if !mapRequest.ReadOnly { + machine.Endpoints = mapRequest.Endpoints machine.LastSeen = &now } @@ -120,25 +134,25 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Err(err). Msg("Failed to persist/update machine in the database") - ctx.String(http.StatusInternalServerError, ":(") + http.Error(writer, "", http.StatusInternalServerError) return } } - data, err := h.getMapResponse(machineKey, req, machine) + data, err := h.getMapResponse(machineKey, mapRequest, machine) if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Err(err). Msg("Failed to get Map response") - ctx.String(http.StatusInternalServerError, ":(") + http.Error(writer, "", http.StatusInternalServerError) return } @@ -150,19 +164,28 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 log.Debug(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). - Bool("readOnly", req.ReadOnly). - Bool("omitPeers", req.OmitPeers). - Bool("stream", req.Stream). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). Msg("Client map request processed") - if req.ReadOnly { + if mapRequest.ReadOnly { log.Info(). Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Client is starting up. Probably interested in a DERP map") - ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(data) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } @@ -177,7 +200,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // Only create update channel if it has not been created log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Loading or creating update channel") @@ -189,13 +212,20 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { keepAliveChan := make(chan []byte) - if req.OmitPeers && !req.Stream { + if mapRequest.OmitPeers && !mapRequest.Stream { log.Info(). Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Client sent endpoint update and is ok with a response without peer list") - ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) - + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(data) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } // It sounds like we should update the nodes when we have received a endpoint update // even tho the comments in the tailscale code dont explicitly say so. updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update"). @@ -203,12 +233,12 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { updateChan <- struct{}{} return - } else if req.OmitPeers && req.Stream { + } else if mapRequest.OmitPeers && mapRequest.Stream { log.Warn(). Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Ignoring request, don't know how to handle it") - ctx.String(http.StatusBadRequest, "") + http.Error(writer, "", http.StatusBadRequest) return } @@ -232,9 +262,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { updateChan <- struct{}{} h.PollNetMapStream( - ctx, - machine, + writer, req, + machine, + mapRequest, machineKey, pollDataChan, keepAliveChan, @@ -242,7 +273,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { ) log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Finished stream, closing PollNetMap session") } @@ -251,7 +282,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // stream logic, ensuring we communicate updates and data // to the connected clients. func (h *Headscale) PollNetMapStream( - ctx *gin.Context, + writer http.ResponseWriter, + req *http.Request, machine *Machine, mapRequest tailcfg.MapRequest, machineKey key.MachinePublic, @@ -259,51 +291,31 @@ func (h *Headscale) PollNetMapStream( keepAliveChan chan []byte, updateChan chan struct{}, ) { - { - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) - ctx.String(http.StatusUnauthorized, "") + ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) - return - } - log.Error(). - Str("handler", "PollNetMap"). - Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) - ctx.String(http.StatusInternalServerError, "") + ctx, cancel := context.WithCancel(ctx) + defer cancel() - return - } + go h.scheduledPollWorker( + ctx, + updateChan, + keepAliveChan, + machineKey, + mapRequest, + machine, + ) - ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname) + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Msg("Waiting for data to stream...") - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go h.scheduledPollWorker( - ctx, - updateChan, - keepAliveChan, - machineKey, - mapRequest, - machine, - ) - } - - ctx.Stream(func(writer io.Writer) bool { - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Msg("Waiting for data to stream...") - - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) + for { select { case data := <-pollDataChan: log.Trace(). @@ -321,8 +333,21 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Cannot write data") - return false + return } + + flusher, ok := writer.(http.Flusher) + if !ok { + log.Error(). + Caller(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Str("channel", "pollData"). + Msg("Cannot cast writer to http.Flusher") + } else { + flusher.Flush() + } + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -343,7 +368,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + return } now := time.Now().UTC() machine.LastSeen = &now @@ -360,16 +385,16 @@ func (h *Headscale) PollNetMapStream( Str("channel", "pollData"). Err(err). Msg("Cannot update machine LastSuccessfulUpdate") - } else { - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Int("bytes", len(data)). - Msg("Machine entry in database updated successfully after sending pollData") + + return } - return true + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Str("channel", "pollData"). + Int("bytes", len(data)). + Msg("Machine entry in database updated successfully after sending data") case data := <-keepAliveChan: log.Trace(). @@ -387,8 +412,20 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Cannot write keep alive message") - return false + return } + flusher, ok := writer.(http.Flusher) + if !ok { + log.Error(). + Caller(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Str("channel", "keepAlive"). + Msg("Cannot cast writer to http.Flusher") + } else { + flusher.Flush() + } + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -409,7 +446,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + return } now := time.Now().UTC() machine.LastSeen = &now @@ -421,16 +458,16 @@ func (h *Headscale) PollNetMapStream( Str("channel", "keepAlive"). Err(err). Msg("Cannot update machine LastSeen") - } else { - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Int("bytes", len(data)). - Msg("Machine updated successfully after sending keep alive") + + return } - return true + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Str("channel", "keepAlive"). + Int("bytes", len(data)). + Msg("Machine updated successfully after sending keep alive") case <-updateChan: log.Trace(). @@ -440,6 +477,7 @@ func (h *Headscale) PollNetMapStream( Msg("Received a request for update") updateRequestsReceivedOnChannel.WithLabelValues(machine.Namespace.Name, machine.Hostname). Inc() + if h.isOutdated(machine) { var lastUpdate time.Time if machine.LastSuccessfulUpdate != nil { @@ -459,6 +497,8 @@ func (h *Headscale) PollNetMapStream( Str("channel", "update"). Err(err). Msg("Could not get the map update") + + return } _, err = writer.Write(data) if err != nil { @@ -471,8 +511,21 @@ func (h *Headscale) PollNetMapStream( updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed"). Inc() - return false + return } + + flusher, ok := writer.(http.Flusher) + if !ok { + log.Error(). + Caller(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Str("channel", "update"). + Msg("Cannot cast writer to http.Flusher") + } else { + flusher.Flush() + } + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -499,7 +552,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + return } now := time.Now().UTC() @@ -515,6 +568,8 @@ func (h *Headscale) PollNetMapStream( Str("channel", "update"). Err(err). Msg("Cannot update machine LastSuccessfulUpdate") + + return } } else { var lastUpdate time.Time @@ -529,9 +584,7 @@ func (h *Headscale) PollNetMapStream( Msgf("%s is up to date", machine.Hostname) } - return true - - case <-ctx.Request.Context().Done(): + case <-ctx.Done(): log.Info(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -550,7 +603,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + return } now := time.Now().UTC() machine.LastSeen = &now @@ -564,9 +617,18 @@ func (h *Headscale) PollNetMapStream( Msg("Cannot update machine LastSeen") } - return false + // The connection has been closed, so we can stop polling. + return + + case <-h.shutdownChan: + log.Info(). + Str("handler", "PollNetMapStream"). + Str("machine", machine.Hostname). + Msg("The long-poll handler is shutting down") + + return } - }) + } } func (h *Headscale) scheduledPollWorker( diff --git a/routes_test.go b/routes_test.go index 0108d888..89b712b5 100644 --- a/routes_test.go +++ b/routes_test.go @@ -28,7 +28,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", - Hostname: "test_get_route_machine", + Hostname: "test_get_route_machine", NamespaceID: namespace.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), @@ -79,7 +79,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", - Hostname: "test_enable_route_machine", + Hostname: "test_enable_route_machine", NamespaceID: namespace.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), diff --git a/swagger.go b/swagger.go index bad348db..588b42ab 100644 --- a/swagger.go +++ b/swagger.go @@ -6,14 +6,16 @@ import ( "html/template" "net/http" - "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" ) //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json var apiV1JSON []byte -func SwaggerUI(ctx *gin.Context) { +func SwaggerUI( + writer http.ResponseWriter, + req *http.Request, +) { swaggerTemplate := template.Must(template.New("swagger").Parse(` @@ -52,18 +54,41 @@ func SwaggerUI(ctx *gin.Context) { Caller(). Err(err). Msg("Could not render Swagger") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render Swagger"), - ) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Swagger")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write(payload.Bytes()) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } -func SwaggerAPIv1(ctx *gin.Context) { - ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) +func SwaggerAPIv1( + writer http.ResponseWriter, + req *http.Request, +) { + writer.Header().Set("Content-Type", "application/json; charset=utf-88") + writer.WriteHeader(http.StatusOK) + if _, err := writer.Write(apiV1JSON); err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } diff --git a/utils.go b/utils.go index fd4cda86..87930a16 100644 --- a/utils.go +++ b/utils.go @@ -324,18 +324,18 @@ func GenerateRandomStringURLSafe(n int) (string, error) { // It will return an error if the system's secure random // number generator fails to function correctly, in which // case the caller should not continue. -func GenerateRandomStringDNSSafe(n int) (string, error) { +func GenerateRandomStringDNSSafe(size int) (string, error) { var str string var err error - for len(str) < n { - str, err = GenerateRandomStringURLSafe(n) + for len(str) < size { + str, err = GenerateRandomStringURLSafe(size) if err != nil { return "", err } str = strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", "")) } - return str[:n], nil + return str[:size], nil } func IsStringInSlice(slice []string, str string) bool {