Merge branch 'main' into update-deps-20220904

This commit is contained in:
Juan Font Alonso 2022-09-04 15:37:25 +02:00
commit 3c73cbe92b
16 changed files with 52 additions and 49 deletions

View file

@ -26,7 +26,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: golangci/golangci-lint-action@v2
with:
version: v1.46.1
version: v1.49.0
# Only block PRs on new problems.
# If this is not enabled, we will end up having PRs

View file

@ -825,7 +825,6 @@ func Test_listMachinesInNamespace(t *testing.T) {
}
}
// nolint
func Test_expandAlias(t *testing.T) {
type args struct {
machines []Machine

2
api.go
View file

@ -52,7 +52,7 @@ func (h *Headscale) HealthHandler(
}
}
if err := h.pingDB(); err != nil {
if err := h.pingDB(req.Context()); err != nil {
respond(err)
return

21
app.go
View file

@ -18,7 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/patrickmn/go-cache"
@ -601,7 +601,7 @@ func (h *Headscale) Serve() error {
grpcOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
grpcMiddleware.ChainUnaryServer(
h.grpcAuthenticationInterceptor,
zerolog.NewUnaryServerInterceptor(),
),
@ -820,10 +820,19 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: HTTPReadTimeout,
}
err := server.ListenAndServe()
go func() {
log.Fatal().
Caller().
Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Err(err).
Msg("failed to set up a HTTP server")
}()
@ -860,20 +869,18 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
}
}
func (h *Headscale) setLastStateChangeToNow(namespaces ...string) {
func (h *Headscale) setLastStateChangeToNow() {
var err error
now := time.Now().UTC()
if len(namespaces) == 0 {
namespaces, err = h.ListNamespacesStr()
namespaces, err := h.ListNamespacesStr()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("failed to fetch all namespaces, failing to update last changed state.")
}
}
for _, namespace := range namespaces {
lastStateUpdate.WithLabelValues(namespace, "headscale").Set(float64(now.Unix()))

View file

@ -5,12 +5,11 @@ import (
"errors"
"fmt"
"io/fs"
"net/netip"
"net/url"
"strings"
"time"
"net/netip"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"

4
db.go
View file

@ -221,8 +221,8 @@ func (h *Headscale) setValue(key string, value string) error {
return nil
}
func (h *Headscale) pingDB() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
func (h *Headscale) pingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
db, err := h.db.DB()
if err != nil {

View file

@ -34,7 +34,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
if err != nil {
return nil, err
}

View file

@ -174,7 +174,7 @@ func (h *Headscale) DERPProbeHandler(
req *http.Request,
) {
switch req.Method {
case "HEAD", "GET":
case http.MethodHead, http.MethodGet:
writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.WriteHeader(http.StatusOK)
default:
@ -202,7 +202,7 @@ func (h *Headscale) DERPBootstrapDNSHandler(
) {
dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
defer cancel()
var resolver net.Resolver
for _, region := range h.DERPMap.Regions {

View file

@ -540,7 +540,6 @@ func Test_getTags(t *testing.T) {
}
}
// nolint
func Test_getFilteredByACLPeers(t *testing.T) {
type args struct {
machines []Machine

View file

@ -31,7 +31,9 @@ func (h *Headscale) NoiseUpgradeHandler(
return
}
server := http.Server{}
server := http.Server{
ReadTimeout: HTTPReadTimeout,
}
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{})
err = server.Serve(netutil.NewOneConnListener(noiseConn, nil))
if err != nil {

10
oidc.go
View file

@ -148,12 +148,12 @@ func (h *Headscale) OIDCCallback(
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state)
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
if err != nil {
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
if err != nil {
return
}
@ -240,10 +240,11 @@ func validateOIDCCallbackParams(
}
func (h *Headscale) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
if err != nil {
log.Error().
Err(err).
@ -287,11 +288,12 @@ func (h *Headscale) getIDTokenForOIDCCallback(
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken)
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
log.Error().
Err(err).

View file

@ -105,7 +105,7 @@ func (h *Headscale) handleRegisterCommon(
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey)
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return
}
@ -134,7 +134,7 @@ func (h *Headscale) handleRegisterCommon(
case <-req.Context().Done():
return
case <-ticker.C:
h.handleNewMachineCommon(writer, req, registerRequest, machineKey)
h.handleNewMachineCommon(writer, registerRequest, machineKey)
return
}
@ -190,7 +190,7 @@ func (h *Headscale) handleRegisterCommon(
registerCacheExpiration,
)
h.handleNewMachineCommon(writer, req, registerRequest, machineKey)
h.handleNewMachineCommon(writer, registerRequest, machineKey)
return
}
@ -207,7 +207,7 @@ func (h *Headscale) handleRegisterCommon(
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, req, *machine, machineKey)
h.handleMachineLogOutCommon(writer, *machine, machineKey)
return
}
@ -215,7 +215,7 @@ func (h *Headscale) handleRegisterCommon(
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleMachineValidRegistrationCommon(writer, req, *machine, machineKey)
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey)
return
}
@ -226,7 +226,6 @@ func (h *Headscale) handleRegisterCommon(
!machine.isExpired() {
h.handleMachineRefreshKeyCommon(
writer,
req,
registerRequest,
*machine,
machineKey,
@ -236,7 +235,7 @@ func (h *Headscale) handleRegisterCommon(
}
// The machine has expired
h.handleMachineExpiredCommon(writer, req, registerRequest, *machine, machineKey)
h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey)
machine.Expiry = &time.Time{}
h.registrationCache.Set(
@ -256,7 +255,6 @@ func (h *Headscale) handleRegisterCommon(
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKeyCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
@ -455,7 +453,6 @@ func (h *Headscale) handleAuthKeyCommon(
// for authorizing the machine. This url is then showed to the user by the local Tailscale client.
func (h *Headscale) handleNewMachineCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
@ -511,7 +508,6 @@ func (h *Headscale) handleNewMachineCommon(
func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter,
req *http.Request,
machine Machine,
machineKey key.MachinePublic,
) {
@ -570,7 +566,6 @@ func (h *Headscale) handleMachineLogOutCommon(
func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter,
req *http.Request,
machine Machine,
machineKey key.MachinePublic,
) {
@ -624,7 +619,6 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
@ -684,7 +678,6 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
func (h *Headscale) handleMachineExpiredCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
@ -699,7 +692,7 @@ func (h *Headscale) handleMachineExpiredCommon(
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey)
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return
}

View file

@ -22,7 +22,7 @@ const machineNameContextKey = contextKey("machineName")
// managed the poll loop.
func (h *Headscale) handlePollCommon(
writer http.ResponseWriter,
req *http.Request,
ctx context.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
isNoise bool,
@ -201,7 +201,7 @@ func (h *Headscale) handlePollCommon(
h.pollNetMapStream(
writer,
req,
ctx,
machine,
mapRequest,
pollDataChan,
@ -221,7 +221,7 @@ func (h *Headscale) handlePollCommon(
// ensuring we communicate updates and data to the connected clients.
func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter,
req *http.Request,
ctxReq context.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
pollDataChan chan []byte,
@ -232,7 +232,7 @@ func (h *Headscale) pollNetMapStream(
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx)
defer cancel()

View file

@ -75,6 +75,8 @@ func (h *Headscale) marshalResponse(
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if machineKey.IsZero() { // if Noise

View file

@ -90,5 +90,5 @@ func (h *Headscale) PollNetMapHandler(
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the legacy protocol")
h.handlePollCommon(writer, req, machine, mapRequest, false)
h.handlePollCommon(writer, req.Context(), machine, mapRequest, false)
}

View file

@ -63,5 +63,5 @@ func (h *Headscale) NoisePollNetMapHandler(
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the Noise protocol")
h.handlePollCommon(writer, req, machine, mapRequest, true)
h.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
}