Merge branch 'main' into update-deps-20220904

This commit is contained in:
Juan Font Alonso 2022-09-04 15:37:25 +02:00
commit 3c73cbe92b
16 changed files with 52 additions and 49 deletions

View file

@ -26,7 +26,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
version: v1.46.1 version: v1.49.0
# Only block PRs on new problems. # Only block PRs on new problems.
# If this is not enabled, we will end up having PRs # If this is not enabled, we will end up having PRs

View file

@ -825,7 +825,6 @@ func Test_listMachinesInNamespace(t *testing.T) {
} }
} }
// nolint
func Test_expandAlias(t *testing.T) { func Test_expandAlias(t *testing.T) {
type args struct { type args struct {
machines []Machine machines []Machine

2
api.go
View file

@ -52,7 +52,7 @@ func (h *Headscale) HealthHandler(
} }
} }
if err := h.pingDB(); err != nil { if err := h.pingDB(req.Context()); err != nil {
respond(err) respond(err)
return return

21
app.go
View file

@ -18,7 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux" "github.com/gorilla/mux"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -601,7 +601,7 @@ func (h *Headscale) Serve() error {
grpcOptions := []grpc.ServerOption{ grpcOptions := []grpc.ServerOption{
grpc.UnaryInterceptor( grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer( grpcMiddleware.ChainUnaryServer(
h.grpcAuthenticationInterceptor, h.grpcAuthenticationInterceptor,
zerolog.NewUnaryServerInterceptor(), zerolog.NewUnaryServerInterceptor(),
), ),
@ -820,10 +820,19 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: HTTPReadTimeout,
}
err := server.ListenAndServe()
go func() { go func() {
log.Fatal(). log.Fatal().
Caller(). Caller().
Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). Err(err).
Msg("failed to set up a HTTP server") Msg("failed to set up a HTTP server")
}() }()
@ -860,20 +869,18 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
} }
} }
func (h *Headscale) setLastStateChangeToNow(namespaces ...string) { func (h *Headscale) setLastStateChangeToNow() {
var err error var err error
now := time.Now().UTC() now := time.Now().UTC()
if len(namespaces) == 0 { namespaces, err := h.ListNamespacesStr()
namespaces, err = h.ListNamespacesStr()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("failed to fetch all namespaces, failing to update last changed state.") Msg("failed to fetch all namespaces, failing to update last changed state.")
} }
}
for _, namespace := range namespaces { for _, namespace := range namespaces {
lastStateUpdate.WithLabelValues(namespace, "headscale").Set(float64(now.Unix())) lastStateUpdate.WithLabelValues(namespace, "headscale").Set(float64(now.Unix()))

View file

@ -5,12 +5,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"net/netip"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"net/netip"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"

4
db.go
View file

@ -221,8 +221,8 @@ func (h *Headscale) setValue(key string, value string) error {
return nil return nil
} }
func (h *Headscale) pingDB() error { func (h *Headscale) pingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel() defer cancel()
db, err := h.db.DB() db, err := h.db.DB()
if err != nil { if err != nil {

View file

@ -34,7 +34,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout) ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -174,7 +174,7 @@ func (h *Headscale) DERPProbeHandler(
req *http.Request, req *http.Request,
) { ) {
switch req.Method { switch req.Method {
case "HEAD", "GET": case http.MethodHead, http.MethodGet:
writer.Header().Set("Access-Control-Allow-Origin", "*") writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
default: default:
@ -202,7 +202,7 @@ func (h *Headscale) DERPBootstrapDNSHandler(
) { ) {
dnsEntries := make(map[string][]net.IP) dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
defer cancel() defer cancel()
var resolver net.Resolver var resolver net.Resolver
for _, region := range h.DERPMap.Regions { for _, region := range h.DERPMap.Regions {

View file

@ -540,7 +540,6 @@ func Test_getTags(t *testing.T) {
} }
} }
// nolint
func Test_getFilteredByACLPeers(t *testing.T) { func Test_getFilteredByACLPeers(t *testing.T) {
type args struct { type args struct {
machines []Machine machines []Machine

View file

@ -31,7 +31,9 @@ func (h *Headscale) NoiseUpgradeHandler(
return return
} }
server := http.Server{} server := http.Server{
ReadTimeout: HTTPReadTimeout,
}
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{}) server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{})
err = server.Serve(netutil.NewOneConnListener(noiseConn, nil)) err = server.Serve(netutil.NewOneConnListener(noiseConn, nil))
if err != nil { if err != nil {

10
oidc.go
View file

@ -148,12 +148,12 @@ func (h *Headscale) OIDCCallback(
return return
} }
rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state) rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
if err != nil { if err != nil {
return return
} }
idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
if err != nil { if err != nil {
return return
} }
@ -240,10 +240,11 @@ func validateOIDCCallbackParams(
} }
func (h *Headscale) getIDTokenForOIDCCallback( func (h *Headscale) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter, writer http.ResponseWriter,
code, state string, code, state string,
) (string, error) { ) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -287,11 +288,12 @@ func (h *Headscale) getIDTokenForOIDCCallback(
} }
func (h *Headscale) verifyIDTokenForOIDCCallback( func (h *Headscale) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter, writer http.ResponseWriter,
rawIDToken string, rawIDToken string,
) (*oidc.IDToken, error) { ) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).

View file

@ -105,7 +105,7 @@ func (h *Headscale) handleRegisterCommon(
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys // If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey) h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return return
} }
@ -134,7 +134,7 @@ func (h *Headscale) handleRegisterCommon(
case <-req.Context().Done(): case <-req.Context().Done():
return return
case <-ticker.C: case <-ticker.C:
h.handleNewMachineCommon(writer, req, registerRequest, machineKey) h.handleNewMachineCommon(writer, registerRequest, machineKey)
return return
} }
@ -190,7 +190,7 @@ func (h *Headscale) handleRegisterCommon(
registerCacheExpiration, registerCacheExpiration,
) )
h.handleNewMachineCommon(writer, req, registerRequest, machineKey) h.handleNewMachineCommon(writer, registerRequest, machineKey)
return return
} }
@ -207,7 +207,7 @@ func (h *Headscale) handleRegisterCommon(
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) { registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, req, *machine, machineKey) h.handleMachineLogOutCommon(writer, *machine, machineKey)
return return
} }
@ -215,7 +215,7 @@ func (h *Headscale) handleRegisterCommon(
// If machine is not expired, and is register, we have a already accepted this machine, // If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration // let it proceed with a valid registration
if !machine.isExpired() { if !machine.isExpired() {
h.handleMachineValidRegistrationCommon(writer, req, *machine, machineKey) h.handleMachineValidRegistrationCommon(writer, *machine, machineKey)
return return
} }
@ -226,7 +226,6 @@ func (h *Headscale) handleRegisterCommon(
!machine.isExpired() { !machine.isExpired() {
h.handleMachineRefreshKeyCommon( h.handleMachineRefreshKeyCommon(
writer, writer,
req,
registerRequest, registerRequest,
*machine, *machine,
machineKey, machineKey,
@ -236,7 +235,7 @@ func (h *Headscale) handleRegisterCommon(
} }
// The machine has expired // The machine has expired
h.handleMachineExpiredCommon(writer, req, registerRequest, *machine, machineKey) h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey)
machine.Expiry = &time.Time{} machine.Expiry = &time.Time{}
h.registrationCache.Set( h.registrationCache.Set(
@ -256,7 +255,6 @@ func (h *Headscale) handleRegisterCommon(
// TODO: check if any locks are needed around IP allocation. // TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKeyCommon( func (h *Headscale) handleAuthKeyCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
@ -455,7 +453,6 @@ func (h *Headscale) handleAuthKeyCommon(
// for authorizing the machine. This url is then showed to the user by the local Tailscale client. // for authorizing the machine. This url is then showed to the user by the local Tailscale client.
func (h *Headscale) handleNewMachineCommon( func (h *Headscale) handleNewMachineCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
@ -511,7 +508,6 @@ func (h *Headscale) handleNewMachineCommon(
func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
@ -570,7 +566,6 @@ func (h *Headscale) handleMachineLogOutCommon(
func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
@ -624,7 +619,6 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -684,7 +678,6 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
func (h *Headscale) handleMachineExpiredCommon( func (h *Headscale) handleMachineExpiredCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -699,7 +692,7 @@ func (h *Headscale) handleMachineExpiredCommon(
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey) h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return return
} }

View file

@ -22,7 +22,7 @@ const machineNameContextKey = contextKey("machineName")
// managed the poll loop. // managed the poll loop.
func (h *Headscale) handlePollCommon( func (h *Headscale) handlePollCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, ctx context.Context,
machine *Machine, machine *Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
isNoise bool, isNoise bool,
@ -201,7 +201,7 @@ func (h *Headscale) handlePollCommon(
h.pollNetMapStream( h.pollNetMapStream(
writer, writer,
req, ctx,
machine, machine,
mapRequest, mapRequest,
pollDataChan, pollDataChan,
@ -221,7 +221,7 @@ func (h *Headscale) handlePollCommon(
// ensuring we communicate updates and data to the connected clients. // ensuring we communicate updates and data to the connected clients.
func (h *Headscale) pollNetMapStream( func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, ctxReq context.Context,
machine *Machine, machine *Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
pollDataChan chan []byte, pollDataChan chan []byte,
@ -232,7 +232,7 @@ func (h *Headscale) pollNetMapStream(
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()

View file

@ -75,6 +75,8 @@ func (h *Headscale) marshalResponse(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot marshal response") Msg("Cannot marshal response")
return nil, err
} }
if machineKey.IsZero() { // if Noise if machineKey.IsZero() { // if Noise

View file

@ -90,5 +90,5 @@ func (h *Headscale) PollNetMapHandler(
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("A machine is entering polling via the legacy protocol") Msg("A machine is entering polling via the legacy protocol")
h.handlePollCommon(writer, req, machine, mapRequest, false) h.handlePollCommon(writer, req.Context(), machine, mapRequest, false)
} }

View file

@ -63,5 +63,5 @@ func (h *Headscale) NoisePollNetMapHandler(
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("A machine is entering polling via the Noise protocol") Msg("A machine is entering polling via the Noise protocol")
h.handlePollCommon(writer, req, machine, mapRequest, true) h.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
} }