Switch to use gorilla's mux as muxer

This commit is contained in:
Juan Font Alonso 2022-06-18 18:41:42 +02:00
parent d5e331a2fb
commit d89fb68a7a
4 changed files with 92 additions and 67 deletions

137
app.go
View file

@ -18,6 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/mux"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -326,48 +327,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
return handler(ctx, req) return handler(ctx, req)
} }
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler {
log.Trace(). return http.HandlerFunc(func(
Caller(). w http.ResponseWriter,
Str("client_address", ctx.ClientIP()). r *http.Request,
Msg("HTTP authentication invoked") ) {
log.Trace().
authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller(). Caller().
Str("client_address", ctx.ClientIP()). Str("client_address", r.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg("HTTP authentication invoked")
ctx.AbortWithStatus(http.StatusUnauthorized)
return authHeader := r.Header.Get("X-Session-Token")
}
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if !strings.HasPrefix(authHeader, AuthPrefix) {
if err != nil { log.Error().
log.Error(). Caller().
Caller(). Str("client_address", r.RemoteAddr).
Err(err). Msg(`missing "Bearer " prefix in "Authorization" header`)
Str("client_address", ctx.ClientIP()). w.WriteHeader(http.StatusUnauthorized)
Msg("failed to validate token") w.Write([]byte("Unauthorized"))
ctx.AbortWithStatus(http.StatusInternalServerError) return
}
return valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
} if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", r.RemoteAddr).
Msg("failed to validate token")
if !valid { w.WriteHeader(http.StatusInternalServerError)
log.Info(). w.Write([]byte("Unauthorized"))
Str("client_address", ctx.ClientIP()).
Msg("invalid token")
ctx.AbortWithStatus(http.StatusUnauthorized) return
}
return if !valid {
} log.Info().
Str("client_address", r.RemoteAddr).
Msg("invalid token")
ctx.Next() w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorized"))
return
}
next.ServeHTTP(w, r)
})
} }
// ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear // ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear
@ -390,39 +399,42 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine {
return promRouter return promRouter
} }
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router := gin.Default() router := mux.NewRouter()
router.GET( router.HandleFunc(
"/health", "/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, func(w http.ResponseWriter, r *http.Request) {
) w.WriteHeader(http.StatusOK)
router.GET("/key", gin.WrapF(h.KeyHandler)) w.Write([]byte("{\"healthy\": \"ok\"}"))
router.GET("/register", gin.WrapF(h.RegisterWebAPI)) }).Methods(http.MethodGet)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.GET("/oidc/register/:mkey", h.RegisterOIDC) router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.GET("/oidc/callback", gin.WrapF(h.OIDCCallback)) router.HandleFunc("/machine/:id/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.GET("/apple", gin.WrapF(h.AppleConfigMessage)) router.HandleFunc("/machine/:id", h.RegistrationHandler).Methods(http.MethodPost)
router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig)) router.HandleFunc("/oidc/register/:mkey", h.RegisterOIDC).Methods(http.MethodGet)
router.GET("/windows", gin.WrapF(h.WindowsConfigMessage)) router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.GET("/windows/tailscale.reg", gin.WrapF(h.WindowsRegConfig)) router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.GET("/swagger", gin.WrapF(SwaggerUI)) router.HandleFunc("/apple/:platform", h.ApplePlatformConfig).Methods(http.MethodGet)
router.GET("/swagger/v1/openapiv2.json", gin.WrapF(SwaggerAPIv1)) router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet)
router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
router.Any("/derp", h.DERPHandler) router.HandleFunc("/derp", h.DERPHandler)
router.Any("/derp/probe", h.DERPProbeHandler) router.HandleFunc("/derp/probe", h.DERPProbeHandler)
router.Any("/bootstrap-dns", h.DERPBootstrapDNSHandler) router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler)
} }
api := router.Group("/api") api := router.PathPrefix("/api").Subrouter()
api.Use(h.httpAuthenticationMiddleware) api.Use(h.httpAuthenticationMiddleware)
{ {
api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP)) api.HandleFunc("/v1/*any", grpcMux.ServeHTTP)
} }
router.NoRoute(stdoutHandler) router.PathPrefix("/").HandlerFunc(stdoutHandler)
return router return router
} }
@ -811,13 +823,16 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
} }
} }
func stdoutHandler(ctx *gin.Context) { func stdoutHandler(
body, _ := io.ReadAll(ctx.Request.Body) w http.ResponseWriter,
r *http.Request,
) {
body, _ := io.ReadAll(r.Body)
log.Trace(). log.Trace().
Interface("header", ctx.Request.Header). Interface("header", r.Header).
Interface("proto", ctx.Request.Proto). Interface("proto", r.Proto).
Interface("url", ctx.Request.URL). Interface("url", r.URL).
Bytes("body", body). Bytes("body", body).
Msg("Request did not match") Msg("Request did not match")
} }

View file

@ -10,7 +10,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/net/stun" "tailscale.com/net/stun"
@ -90,7 +89,10 @@ func (h *Headscale) generateRegionLocalDERP() (tailcfg.DERPRegion, error) {
return localDERPregion, nil return localDERPregion, nil
} }
func (h *Headscale) DERPHandler(ctx *gin.Context) { func (h *Headscale) DERPHandler(
w http.ResponseWriter,
r *http.Request,
) {
log.Trace().Caller().Msgf("/derp request from %v", ctx.ClientIP()) log.Trace().Caller().Msgf("/derp request from %v", ctx.ClientIP())
up := strings.ToLower(ctx.Request.Header.Get("Upgrade")) up := strings.ToLower(ctx.Request.Header.Get("Upgrade"))
if up != "websocket" && up != "derp" { if up != "websocket" && up != "derp" {
@ -143,7 +145,10 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) {
// DERPProbeHandler is the endpoint that js/wasm clients hit to measure // DERPProbeHandler is the endpoint that js/wasm clients hit to measure
// DERP latency, since they can't do UDP STUN queries. // DERP latency, since they can't do UDP STUN queries.
func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { func (h *Headscale) DERPProbeHandler(
w http.ResponseWriter,
r *http.Request,
) {
switch ctx.Request.Method { switch ctx.Request.Method {
case "HEAD", "GET": case "HEAD", "GET":
ctx.Writer.Header().Set("Access-Control-Allow-Origin", "*") ctx.Writer.Header().Set("Access-Control-Allow-Origin", "*")
@ -159,15 +164,18 @@ func (h *Headscale) DERPProbeHandler(ctx *gin.Context) {
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns // An example implementation is found here https://derp.tailscale.com/bootstrap-dns
func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) { func (h *Headscale) DERPBootstrapDNSHandler(
w http.ResponseWriter,
r *http.Request,
) {
dnsEntries := make(map[string][]net.IP) dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
var r net.Resolver var resolver net.Resolver
for _, region := range h.DERPMap.Regions { for _, region := range h.DERPMap.Regions {
for _, node := range region.Nodes { // we don't care if we override some nodes for _, node := range region.Nodes { // we don't care if we override some nodes
addrs, err := r.LookupIP(resolvCtx, "ip", node.HostName) addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName)
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().

1
go.mod
View file

@ -73,6 +73,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/gookit/color v1.5.0 // indirect github.com/gookit/color v1.5.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/hashicorp/go-version v1.4.0 // indirect github.com/hashicorp/go-version v1.4.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/imdario/mergo v0.3.12 // indirect github.com/imdario/mergo v0.3.12 // indirect

1
go.sum
View file

@ -403,6 +403,7 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU=
github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw= github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw=
github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA= github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=