improve testing of route failover logic

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-04-10 15:35:09 +02:00 committed by Juan Font
parent bf4fd078fc
commit 1704977e76
11 changed files with 518 additions and 143 deletions

View file

@ -503,7 +503,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router return router
} }
// Serve launches a GIN server with the Headscale API. // Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
@ -532,7 +532,7 @@ func (h *Headscale) Serve() error {
region, err := h.DERPServer.GenerateRegion() region, err := h.DERPServer.GenerateRegion()
if err != nil { if err != nil {
return err return fmt.Errorf("generating DERP region for embedded server: %w", err)
} }
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
@ -607,14 +607,14 @@ func (h *Headscale) Serve() error {
}..., }...,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("setting up gRPC gateway via socket: %w", err)
} }
// Connect to the gRPC server over localhost to skip // Connect to the gRPC server over localhost to skip
// the authentication. // the authentication.
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn) err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
if err != nil { if err != nil {
return err return fmt.Errorf("registering Headscale API service to gRPC: %w", err)
} }
// Start the local gRPC server without TLS and without authentication // Start the local gRPC server without TLS and without authentication
@ -635,9 +635,7 @@ func (h *Headscale) Serve() error {
tlsConfig, err := h.getTLSSettings() tlsConfig, err := h.getTLSSettings()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration") return fmt.Errorf("configuring TLS settings: %w", err)
return err
} }
// //
@ -702,15 +700,11 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: router, Handler: router,
ReadTimeout: types.HTTPReadTimeout, ReadTimeout: types.HTTPTimeout,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to // Long polling should not have any timeout, this is overriden
// keep this at unlimited and be careful to clean up connections // further down the chain
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming WriteTimeout: types.HTTPTimeout,
// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
// replace this so only the longpoller has no timeout.
WriteTimeout: 0,
} }
var httpListener net.Listener var httpListener net.Listener
@ -729,27 +723,46 @@ func (h *Headscale) Serve() error {
log.Info(). log.Info().
Msgf("listening and serving HTTP on: %s", h.cfg.Addr) Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
promMux := http.NewServeMux() debugMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler()) debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String()))
promHTTPServer := &http.Server{ return
})
debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) {
h.mapSessionMu.Lock()
defer h.mapSessionMu.Unlock()
var b strings.Builder
b.WriteString("mapresponders:\n")
for k, v := range h.mapSessions {
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(b.String()))
return
})
debugMux.Handle("/metrics", promhttp.Handler())
debugHTTPServer := &http.Server{
Addr: h.cfg.MetricsAddr, Addr: h.cfg.MetricsAddr,
Handler: promMux, Handler: debugMux,
ReadTimeout: types.HTTPReadTimeout, ReadTimeout: types.HTTPTimeout,
WriteTimeout: 0, WriteTimeout: 0,
} }
var promHTTPListener net.Listener debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr)
promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err) return fmt.Errorf("failed to bind to TCP address: %w", err)
} }
errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) }) errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })
log.Info(). log.Info().
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)
var tailsqlContext context.Context var tailsqlContext context.Context
if tailsqlEnabled { if tailsqlEnabled {
@ -815,7 +828,7 @@ func (h *Headscale) Serve() error {
context.Background(), context.Background(),
types.HTTPShutdownTimeout, types.HTTPShutdownTimeout,
) )
if err := promHTTPServer.Shutdown(ctx); err != nil { if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http") log.Error().Err(err).Msg("Failed to shutdown prometheus http")
} }
if err := httpServer.Shutdown(ctx); err != nil { if err := httpServer.Shutdown(ctx); err != nil {
@ -833,7 +846,7 @@ func (h *Headscale) Serve() error {
} }
// Close network listeners // Close network listeners
promHTTPListener.Close() debugHTTPListener.Close()
httpListener.Close() httpListener.Close()
grpcGatewayConn.Close() grpcGatewayConn.Close()
@ -898,7 +911,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
server := &http.Server{ server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen, Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)), Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: types.HTTPReadTimeout, ReadTimeout: types.HTTPTimeout,
} }
go func() { go func() {

View file

@ -4,11 +4,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"sort"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/util/set"
) )
var ErrRouteIsNotAvailable = errors.New("route is not available") var ErrRouteIsNotAvailable = errors.New("route is not available")
@ -402,11 +404,10 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
return sendUpdate, nil return sendUpdate, nil
} }
// FailoverRouteIfAvailable takes a node and checks if the node's route // FailoverNodeRoutesIfNeccessary takes a node and checks if the node's route
// currently have a functioning host that exposes the network. // need to be failed over to another host.
// If it does not, it is failed over to another suitable route if there // If needed, the failover will be attempted.
// is one. func FailoverNodeRoutesIfNeccessary(
func FailoverRouteIfAvailable(
tx *gorm.DB, tx *gorm.DB,
isConnected types.NodeConnectedMap, isConnected types.NodeConnectedMap,
node *types.Node, node *types.Node,
@ -416,8 +417,12 @@ func FailoverRouteIfAvailable(
return nil, nil return nil, nil
} }
var changedNodes []types.NodeID log.Trace().Msgf("NODE ROUTES: %d", len(nodeRoutes))
changedNodes := make(set.Set[types.NodeID])
nodeRouteLoop:
for _, nodeRoute := range nodeRoutes { for _, nodeRoute := range nodeRoutes {
log.Trace().Msgf("NODE ROUTE: %d", nodeRoute.ID)
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil { if err != nil {
return nil, fmt.Errorf("getting routes by prefix: %w", err) return nil, fmt.Errorf("getting routes by prefix: %w", err)
@ -427,29 +432,37 @@ func FailoverRouteIfAvailable(
if route.IsPrimary { if route.IsPrimary {
// if we have a primary route, and the node is connected // if we have a primary route, and the node is connected
// nothing needs to be done. // nothing needs to be done.
if isConnected[route.Node.ID] { if conn, ok := isConnected[route.Node.ID]; conn && ok {
return nil, nil continue nodeRouteLoop
} }
// if not, we need to failover the route // if not, we need to failover the route
failover := failoverRoute(isConnected, &route, routes) failover := failoverRoute(isConnected, &route, routes)
if failover != nil { if failover != nil {
failover.save(tx) err := failover.save(tx)
if err != nil { if err != nil {
return nil, fmt.Errorf("saving failover routes: %w", err) return nil, fmt.Errorf("saving failover routes: %w", err)
} }
changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID) changedNodes.Add(failover.old.Node.ID)
changedNodes.Add(failover.new.Node.ID)
continue nodeRouteLoop
} }
} }
} }
} }
chng := changedNodes.Slice()
sort.SliceStable(chng, func(i, j int) bool {
return chng[i] < chng[j]
})
if len(changedNodes) != 0 { if len(changedNodes) != 0 {
return &types.StateUpdate{ return &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: changedNodes, ChangeNodes: chng,
Message: "called from db.FailoverRouteIfAvailable", Message: "called from db.FailoverNodeRoutesIfNeccessary",
}, nil }, nil
} }

View file

@ -7,9 +7,9 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -270,6 +270,370 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
} }
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
var n = func(nid types.NodeID) types.Node {
return types.Node{ID: nid}
}
var np = func(nid types.NodeID) *types.Node {
no := n(nid)
return &no
}
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
return types.Route{
Model: gorm.Model{
ID: id,
},
Node: n(nid),
Prefix: prefix,
Enabled: enabled,
IsPrimary: primary,
}
}
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
func dbForTest(t *testing.T, testName string) *HSDatabase {
t.Helper()
tmpDir, err := os.MkdirTemp("", testName)
if err != nil {
t.Fatalf("creating tempdir: %s", err)
}
dbPath := tmpDir + "/headscale_test.db"
db, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
},
"",
)
if err != nil {
t.Fatalf("setting up database: %s", err)
}
t.Logf("database set up at: %s", dbPath)
return db
}
func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
su := func(nids ...types.NodeID) *types.StateUpdate {
return &types.StateUpdate{
ChangeNodes: nids,
}
}
tests := []struct {
name string
nodes types.Nodes
routes types.Routes
isConnected []types.NodeConnectedMap
want []*types.StateUpdate
wantErr bool
}{
{
name: "n1-down-n2-down-n1-up",
nodes: types.Nodes{
np(1),
np(2),
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: true,
},
// n2 goes down
{
1: false,
2: false,
},
// n1 comes up
{
1: true,
2: false,
},
},
want: []*types.StateUpdate{
// route changes from 1 -> 2
su(1, 2),
// both down, no change
nil,
// route changes from 2 -> 1
su(1, 2),
},
},
{
name: "n1-recon-n2-down-n1-recon-n2-up",
nodes: types.Nodes{
np(1),
np(2),
np(1),
np(2),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 up recon = noop
{
1: true,
2: true,
},
// n2 goes down
{
1: true,
2: false,
},
// n1 up recon = noop
{
1: true,
2: false,
},
// n2 comes back up
{
1: true,
2: false,
},
},
want: []*types.StateUpdate{
nil,
nil,
nil,
nil,
},
},
{
name: "n1-recon-n2-down-n1-recon-n2-up",
nodes: types.Nodes{
np(1),
np(1),
np(3),
np(3),
np(2),
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.0.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: false,
3: true,
},
// n1 comes up
{
1: true,
2: false,
3: true,
},
// n3 goes down
{
1: true,
2: false,
3: false,
},
// n3 comes up
{
1: true,
2: false,
3: true,
},
// n2 comes up
{
1: true,
2: true,
3: true,
},
// n1 goes down
{
1: false,
2: true,
3: true,
},
},
want: []*types.StateUpdate{
su(1, 3), // n1 -> n3
nil,
su(1, 3), // n3 -> n1
nil,
nil,
su(1, 2), // n1 -> n2
},
},
{
name: "n1-recon-n2-dis-n3-take",
nodes: types.Nodes{
np(1),
np(3),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), false, false),
r(3, 3, ipp("10.0.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: true,
3: true,
},
// n3 goes down
{
1: false,
2: true,
3: false,
},
},
want: []*types.StateUpdate{
su(1, 3), // n1 -> n3
nil,
},
},
{
name: "multi-n1-oneforeach-n2-n3",
nodes: types.Nodes{
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(4, 1, ipp("10.1.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: true,
3: true,
},
},
want: []*types.StateUpdate{
su(1, 2, 3), // n1 -> n2,n3
},
},
{
name: "multi-n1-onefor-n2-disabled-n3",
nodes: types.Nodes{
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(4, 1, ipp("10.1.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), false, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: true,
3: true,
},
},
want: []*types.StateUpdate{
su(1, 2), // n1 -> n2, n3 is not enabled
},
},
{
name: "multi-n1-onefor-n2-offline-n3",
nodes: types.Nodes{
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(4, 1, ipp("10.1.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: false,
2: true,
3: false,
},
},
want: []*types.StateUpdate{
su(1, 2), // n1 -> n2, n3 is offline
},
},
{
name: "multi-n2-back-to-multi-n1",
nodes: types.Nodes{
np(1),
},
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, false),
r(4, 1, ipp("10.1.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, true),
r(3, 3, ipp("10.1.0.0/24"), true, false),
},
isConnected: []types.NodeConnectedMap{
// n1 goes down
{
1: true,
2: false,
3: true,
},
},
want: []*types.StateUpdate{
su(1, 2), // n2 -> n1
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if (len(tt.isConnected) != len(tt.want)) && len(tt.want) != len(tt.nodes) {
t.Fatalf("nodes (%d), isConnected updates (%d), wants (%d) must be equal", len(tt.nodes), len(tt.isConnected), len(tt.want))
}
db := dbForTest(t, tt.name)
for _, route := range tt.routes {
if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err)
}
}
for step := range len(tt.isConnected) {
node := tt.nodes[step]
isConnected := tt.isConnected[step]
want := tt.want[step]
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return FailoverNodeRoutesIfNeccessary(tx, isConnected, node)
})
if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(types.StateUpdate{}, "Type", "Message")); diff != "" {
t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff)
}
}
})
}
}
func TestFailoverRouteTx(t *testing.T) { func TestFailoverRouteTx(t *testing.T) {
tests := []struct { tests := []struct {
@ -637,19 +1001,7 @@ func TestFailoverRouteTx(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test") db := dbForTest(t, tt.name)
assert.NoError(t, err)
db, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
"",
)
assert.NoError(t, err)
for _, route := range tt.routes { for _, route := range tt.routes {
if err := db.DB.Save(&route).Error; err != nil { if err := db.DB.Save(&route).Error; err != nil {

View file

@ -31,7 +31,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
} }
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), types.HTTPReadTimeout) ctx, cancel := context.WithTimeout(context.Background(), types.HTTPTimeout)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
@ -40,7 +40,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
} }
client := http.Client{ client := http.Client{
Timeout: types.HTTPReadTimeout, Timeout: types.HTTPTimeout,
} }
resp, err := client.Do(req) resp, err := client.Do(req)

View file

@ -3,7 +3,6 @@ package hscontrol
import ( import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
@ -12,7 +11,6 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"gorm.io/gorm"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -103,12 +101,12 @@ func (h *Headscale) NoiseUpgradeHandler(
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
server := http.Server{ server := http.Server{
ReadTimeout: types.HTTPReadTimeout, ReadTimeout: types.HTTPTimeout,
} }
noiseServer.httpBaseConfig = &http.Server{ noiseServer.httpBaseConfig = &http.Server{
Handler: router, Handler: router,
ReadHeaderTimeout: types.HTTPReadTimeout, ReadHeaderTimeout: types.HTTPTimeout,
} }
noiseServer.http2Server = &http2.Server{} noiseServer.http2Server = &http2.Server{}
@ -225,15 +223,6 @@ func (ns *noiseServer) NoisePollNetMapHandler(
key.NodePublic{}, key.NodePublic{},
) )
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()).
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)
return
}
log.Error(). log.Error().
Str("handler", "NoisePollNetMap"). Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()). Uint64("node.id", node.ID.Uint64()).
@ -242,58 +231,59 @@ func (ns *noiseServer) NoisePollNetMapHandler(
return return
} }
log.Debug(). sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
Str("handler", "NoisePollNetMap").
Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)).
Uint64("node.id", node.ID.Uint64()).
Msg("A node sending a MapRequest with Noise protocol")
session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node) sess.tracef("a node sending a MapRequest with Noise protocol")
// If a streaming mapSession exists for this node, close it // If a streaming mapSession exists for this node, close it
// and start a new one. // and start a new one.
if session.isStreaming() { if sess.isStreaming() {
log.Debug(). sess.tracef("aquiring lock to check stream")
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to check stream")
ns.headscale.mapSessionMu.Lock() ns.headscale.mapSessionMu.Lock()
if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok { if _, ok := ns.headscale.mapSessions[node.ID]; ok {
log.Info(). // NOTE/TODO(kradalby): From how I understand the protocol, when
Caller(). // a client connects with stream=true, and already has a streaming
Uint64("node.id", node.ID.Uint64()). // connection open, the correct way is to close the current channel
Msg("Node has an open streaming session, replacing") // and replace it. However, I cannot manage to get that working with
oldSession.close() // some sort of lock/block happening on the cancelCh in the streaming
// session.
// Not closing the channel and replacing it puts us in a weird state
// which keeps a ghost stream open, receiving keep alives, but no updates.
//
// Typically a new connection is opened when one exists as a client which
// is already authenticated reconnects (e.g. down, then up). The client will
// start auth and streaming at the same time, and then cancel the streaming
// when the auth has finished successfully, opening a new connection.
//
// As a work-around to not replacing, abusing the clients "resilience"
// by reject the new connection which will cause the client to immediately
// reconnect and "fix" the issue, as the other connection typically has been
// closed, meaning there is nothing to replace.
//
// sess.infof("node has an open stream(%p), replacing with %p", oldSession, sess)
// oldSession.close()
defer ns.headscale.mapSessionMu.Unlock()
sess.infof("node has an open stream(%p), rejecting new stream", sess)
return
} }
ns.headscale.mapSessions[node.ID] = session ns.headscale.mapSessions[node.ID] = sess
ns.headscale.mapSessionMu.Unlock() ns.headscale.mapSessionMu.Unlock()
log.Debug(). sess.tracef("releasing lock to check stream")
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to check stream")
} }
session.serve() sess.serve()
if session.isStreaming() { if sess.isStreaming() {
log.Debug(). sess.tracef("aquiring lock to remove stream")
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to remove stream")
ns.headscale.mapSessionMu.Lock() ns.headscale.mapSessionMu.Lock()
defer ns.headscale.mapSessionMu.Unlock()
delete(ns.headscale.mapSessions, node.ID) delete(ns.headscale.mapSessions, node.ID)
ns.headscale.mapSessionMu.Unlock() sess.tracef("releasing lock to remove stream")
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to remove stream")
} }
} }

View file

@ -172,11 +172,19 @@ func (n *Notifier) String() string {
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
str := []string{"Notifier, in map:\n"} var b strings.Builder
b.WriteString("chans:\n")
for k, v := range n.nodes { for k, v := range n.nodes {
str = append(str, fmt.Sprintf("\t%d: %v\n", k, v)) fmt.Fprintf(&b, "\t%d: %p\n", k, v)
} }
return strings.Join(str, "") b.WriteString("\n")
b.WriteString("connected:\n")
for k, v := range n.connected {
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
}
return b.String()
} }

View file

@ -48,6 +48,8 @@ type mapSession struct {
ch chan types.StateUpdate ch chan types.StateUpdate
cancelCh chan struct{} cancelCh chan struct{}
keepAliveTicker *time.Ticker
node *types.Node node *types.Node
w http.ResponseWriter w http.ResponseWriter
@ -85,6 +87,8 @@ func (h *Headscale) newMapSession(
ch: updateChan, ch: updateChan,
cancelCh: make(chan struct{}), cancelCh: make(chan struct{}),
keepAliveTicker: time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)),
// Loggers // Loggers
warnf: warnf, warnf: warnf,
infof: infof, infof: infof,
@ -100,10 +104,9 @@ func (m *mapSession) close() {
return return
} }
select { m.tracef("mapSession (%p) sending message on cancel chan")
case m.cancelCh <- struct{}{}: m.cancelCh <- struct{}{}
default: m.tracef("mapSession (%p) sent message on cancel chan")
}
} }
func (m *mapSession) isStreaming() bool { func (m *mapSession) isStreaming() bool {
@ -118,13 +121,6 @@ func (m *mapSession) isReadOnlyUpdate() bool {
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
} }
func (m *mapSession) flush200() {
m.w.WriteHeader(http.StatusOK)
if f, ok := m.w.(http.Flusher); ok {
f.Flush()
}
}
// handlePoll ensures the node gets the appropriate updates from either // handlePoll ensures the node gets the appropriate updates from either
// polling or immediate responses. // polling or immediate responses.
// //
@ -211,7 +207,12 @@ func (m *mapSession) serve() {
m.pollFailoverRoutes("node connected", m.node) m.pollFailoverRoutes("node connected", m.node)
keepAliveTicker := time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)) // Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
// Longpolling will break if there is a write timeout,
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel() defer cancel()
@ -324,18 +325,16 @@ func (m *mapSession) serve() {
startWrite := time.Now() startWrite := time.Now()
_, err = m.w.Write(data) _, err = m.w.Write(data)
if err != nil { if err != nil {
m.errf(err, "Could not write the map response, for mapSession: %p, stream: %t", m, m.isStreaming()) m.errf(err, "Could not write the map response, for mapSession: %p", m)
return return
} }
if flusher, ok := m.w.(http.Flusher); ok { err = rc.Flush()
flusher.Flush() if err != nil {
} else { m.errf(err, "flushing the map response to client, for mapSession: %p", m)
log.Error().Msg("Failed to create http flusher")
return return
} }
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
m.infof("update sent") m.infof("update sent")
@ -402,7 +401,7 @@ func (m *mapSession) serve() {
derp = true derp = true
} }
case <-keepAliveTicker.C: case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node) data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil { if err != nil {
m.errf(err, "Error generating the keep alive msg") m.errf(err, "Error generating the keep alive msg")
@ -415,11 +414,9 @@ func (m *mapSession) serve() {
return return
} }
if flusher, ok := m.w.(http.Flusher); ok { err = rc.Flush()
flusher.Flush() if err != nil {
} else { m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
log.Error().Msg("Failed to create http flusher")
return return
} }
} }
@ -428,7 +425,7 @@ func (m *mapSession) serve() {
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.FailoverRouteIfAvailable(tx, m.h.nodeNotifier.ConnectedMap(), node) return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node)
}) })
if err != nil { if err != nil {
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
@ -565,7 +562,7 @@ func (m *mapSession) handleEndpointUpdate() {
}, },
m.node.ID) m.node.ID)
m.flush200() m.w.WriteHeader(http.StatusOK)
return return
} }
@ -654,7 +651,9 @@ func (m *mapSession) handleReadOnlyRequest() {
m.errf(err, "Failed to write response") m.errf(err, "Failed to write response")
} }
m.flush200() m.w.WriteHeader(http.StatusOK)
return
} }
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {

View file

@ -3,7 +3,7 @@ package types
import "time" import "time"
const ( const (
HTTPReadTimeout = 30 * time.Second HTTPTimeout = 30 * time.Second
HTTPShutdownTimeout = 3 * time.Second HTTPShutdownTimeout = 3 * time.Second
TLSALPN01ChallengeType = "TLS-ALPN-01" TLSALPN01ChallengeType = "TLS-ALPN-01"
HTTP01ChallengeType = "HTTP-01" HTTP01ChallengeType = "HTTP-01"

View file

@ -124,7 +124,7 @@ func DefaultConfigEnv() map[string]string {
"HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key", "HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key",
"HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key", "HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key",
"HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080", "HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080",
"HEADSCALE_METRICS_LISTEN_ADDR": "127.0.0.1:9090", "HEADSCALE_METRICS_LISTEN_ADDR": "0.0.0.0:9090",
"HEADSCALE_SERVER_URL": "http://headscale:8080", "HEADSCALE_SERVER_URL": "http://headscale:8080",
"HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default", "HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default",
"HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false", "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false",

View file

@ -260,7 +260,7 @@ func New(
runOptions := &dockertest.RunOptions{ runOptions := &dockertest.RunOptions{
Name: hsic.hostname, Name: hsic.hostname,
ExposedPorts: append([]string{portProto}, hsic.extraPorts...), ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...),
Networks: []*dockertest.Network{network}, Networks: []*dockertest.Network{network},
// Cmd: []string{"headscale", "serve"}, // Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some // TODO(kradalby): Get rid of this hack, we currently need to give us some

View file

@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
scenario, err := NewScenario() scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown() // defer scenario.Shutdown()
spec := map[string]int{ spec := map[string]int{
user: 3, user: 3,