Remove gin from the poll handlers

This commit is contained in:
Juan Font Alonso 2022-06-20 12:30:51 +02:00
parent dedeb4c181
commit 53e5c05b0a

125
poll.go
View file

@ -2,13 +2,14 @@ package headscale
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -33,13 +34,25 @@ const machineNameContextKey = contextKey("machineName")
// only after their first request (marked with the ReadOnly field). // only after their first request (marked with the ReadOnly field).
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { func (h *Headscale) PollNetMapHandler(
w http.ResponseWriter,
r *http.Request,
) {
vars := mux.Vars(r)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(w, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Msg("PollNetMapHandler called") Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body) body, _ := io.ReadAll(r.Body)
machineKeyStr := ctx.Param("id")
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
@ -48,7 +61,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot parse client key") Msg("Cannot parse client key")
ctx.String(http.StatusBadRequest, "")
http.Error(w, "Cannot parse client key", http.StatusBadRequest)
return return
} }
@ -59,7 +73,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
ctx.String(http.StatusBadRequest, "") http.Error(w, "Cannot decode message", http.StatusBadRequest)
return return
} }
@ -70,20 +84,21 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
http.Error(w, "", http.StatusUnauthorized)
return return
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "") http.Error(w, "", http.StatusInternalServerError)
return return
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Found machine in database") Msg("Found machine in database")
@ -120,11 +135,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to persist/update machine in the database") Msg("Failed to persist/update machine in the database")
ctx.String(http.StatusInternalServerError, ":(") http.Error(w, "", http.StatusInternalServerError)
return return
} }
@ -134,11 +149,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
ctx.String(http.StatusInternalServerError, ":(") http.Error(w, "", http.StatusInternalServerError)
return return
} }
@ -150,7 +165,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Bool("readOnly", req.ReadOnly). Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers). Bool("omitPeers", req.OmitPeers).
@ -162,7 +177,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(data)
return return
} }
@ -177,7 +195,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Only create update channel if it has not been created // Only create update channel if it has not been created
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
@ -194,8 +212,9 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(data)
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update"). updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
@ -208,7 +227,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
ctx.String(http.StatusBadRequest, "") http.Error(w, "", http.StatusBadRequest)
return return
} }
@ -232,7 +251,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
updateChan <- struct{}{} updateChan <- struct{}{}
h.PollNetMapStream( h.PollNetMapStream(
ctx, w,
r,
machine, machine,
req, req,
machineKey, machineKey,
@ -242,7 +262,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
) )
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", ctx.Param("id")). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
@ -251,7 +271,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// stream logic, ensuring we communicate updates and data // stream logic, ensuring we communicate updates and data
// to the connected clients. // to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
ctx *gin.Context, w http.ResponseWriter,
r *http.Request,
machine *Machine, machine *Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -259,26 +280,7 @@ func (h *Headscale) PollNetMapStream(
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
) { ) {
{ ctx := context.WithValue(context.Background(), machineNameContextKey, machine.Hostname)
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
return
}
ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -291,9 +293,8 @@ func (h *Headscale) PollNetMapStream(
mapRequest, mapRequest,
machine, machine,
) )
}
ctx.Stream(func(writer io.Writer) bool { for {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -312,7 +313,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
_, err := writer.Write(data) _, err := w.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -321,7 +322,7 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
return false break
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -343,7 +344,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false break
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -369,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Machine entry in database updated successfully after sending pollData") Msg("Machine entry in database updated successfully after sending pollData")
} }
return true break
case data := <-keepAliveChan: case data := <-keepAliveChan:
log.Trace(). log.Trace().
@ -378,7 +379,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending keep alive message") Msg("Sending keep alive message")
_, err := writer.Write(data) _, err := w.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -387,7 +388,7 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
return false break
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -409,7 +410,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false break
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -430,7 +431,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Machine updated successfully after sending keep alive") Msg("Machine updated successfully after sending keep alive")
} }
return true break
case <-updateChan: case <-updateChan:
log.Trace(). log.Trace().
@ -460,7 +461,7 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
} }
_, err = writer.Write(data) _, err = w.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -471,7 +472,7 @@ func (h *Headscale) PollNetMapStream(
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed"). updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
Inc() Inc()
return false return
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -499,7 +500,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false return
} }
now := time.Now().UTC() now := time.Now().UTC()
@ -529,9 +530,9 @@ func (h *Headscale) PollNetMapStream(
Msgf("%s is up to date", machine.Hostname) Msgf("%s is up to date", machine.Hostname)
} }
return true return
case <-ctx.Request.Context().Done(): case <-ctx.Done():
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -550,7 +551,7 @@ func (h *Headscale) PollNetMapStream(
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return false break
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now machine.LastSeen = &now
@ -564,9 +565,11 @@ func (h *Headscale) PollNetMapStream(
Msg("Cannot update machine LastSeen") Msg("Cannot update machine LastSeen")
} }
return false break
} }
}) }
log.Info().Msgf("Closing poll loop to %s", machine.Hostname)
} }
func (h *Headscale) scheduledPollWorker( func (h *Headscale) scheduledPollWorker(