From 036061664ef1f1f97058a65eba1d97a82ef2cd20 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Aug 2021 16:12:01 +0100 Subject: [PATCH 01/18] initial integration test file --- .github/workflows/test-integration.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/test-integration.yml diff --git a/.github/workflows/test-integration.yml b/.github/workflows/test-integration.yml new file mode 100644 index 00000000..e939df22 --- /dev/null +++ b/.github/workflows/test-integration.yml @@ -0,0 +1,23 @@ +name: CI + +on: [pull_request] + +jobs: + # The "build" workflow + integration-test: + # The type of runner that the job will run on + runs-on: ubuntu-latest + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + + # Setup Go + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: "1.16.3" + + - name: Run Integration tests + run: go test -tags integration -timeout 30m From 7d1a5c00a099d103103be9312823a8f3651d61b7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Aug 2021 16:56:28 +0100 Subject: [PATCH 02/18] Try with longer timeout --- integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_test.go b/integration_test.go index dd96fb82..4c1c54b4 100644 --- a/integration_test.go +++ b/integration_test.go @@ -295,7 +295,7 @@ func (s *IntegrationTestSuite) TestPingAllPeers() { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { // We currently cant ping ourselves, so skip that. 
if peername != hostname { - command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()} + command := []string{"tailscale", "ping", "--timeout=5s", "--c=1", ip.String()} fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) result, err := executeCommand( From 6fa61380b25c522d43f4889993ee0ad55b432c5c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Aug 2021 23:17:09 +0100 Subject: [PATCH 03/18] Up client count, make arguments more explicit and clean up unused assignments --- integration_test.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/integration_test.go b/integration_test.go index 4c1c54b4..892c7ece 100644 --- a/integration_test.go +++ b/integration_test.go @@ -34,7 +34,7 @@ var ih Headscale var pool dockertest.Pool var network dockertest.Network var headscale dockertest.Resource -var tailscaleCount int = 5 +var tailscaleCount int = 20 var tailscales map[string]dockertest.Resource func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) { @@ -115,7 +115,6 @@ func (s *IntegrationTestSuite) SetupSuite() { PortBindings: map[docker.Port][]docker.PortBinding{ "8080/tcp": []docker.PortBinding{{HostPort: "8080"}}, }, - Env: []string{}, } fmt.Println("Creating headscale container") @@ -134,7 +133,6 @@ func (s *IntegrationTestSuite) SetupSuite() { Name: hostname, Networks: []*dockertest.Network{&network}, Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"}, - Env: []string{}, } if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil { @@ -145,7 +143,6 @@ func (s *IntegrationTestSuite) SetupSuite() { fmt.Printf("Created %s container\n", hostname) } - // TODO: Replace this logic with something that can be detected on Github Actions fmt.Println("Waiting for headscale to be ready") hostEndpoint := fmt.Sprintf("localhost:%s", 
headscale.GetPort("8080/tcp")) @@ -197,20 +194,20 @@ func (s *IntegrationTestSuite) SetupSuite() { // The nodes need a bit of time to get their updated maps from headscale // TODO: See if we can have a more deterministic wait here. - time.Sleep(20 * time.Second) + time.Sleep(120 * time.Second) } func (s *IntegrationTestSuite) TearDownSuite() { - if err := pool.Purge(&headscale); err != nil { - log.Printf("Could not purge resource: %s\n", err) - } - for _, tailscale := range tailscales { if err := pool.Purge(&tailscale); err != nil { log.Printf("Could not purge resource: %s\n", err) } } + if err := pool.Purge(&headscale); err != nil { + log.Printf("Could not purge resource: %s\n", err) + } + if err := network.Close(); err != nil { log.Printf("Could not close network: %s\n", err) } @@ -295,7 +292,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { // We currently cant ping ourselves, so skip that. if peername != hostname { - command := []string{"tailscale", "ping", "--timeout=5s", "--c=1", ip.String()} + // We are only interested in "direct ping" which means what we + // might need a couple of more attempts before reaching the node. + command := []string{ + "tailscale", "ping", + "--timeout=1s", + "--c=20", + "--until-direct=true", + ip.String(), + } fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) result, err := executeCommand( From 2f883410d2a53f291979260810878d8f062b0771 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Aug 2021 23:17:38 +0100 Subject: [PATCH 04/18] Add lastUpdate field to machine, function issue message on update channel This commit adds a new field to machine, LastSuccessfulUpdate, which tracks when we were last able to send a proper map update to the node. The purpose of this is to be able to compare to a "global" last updated time and determine if we need to send an update map request to a node. 
In addition it allows us to create a scheduled check to see if all known nodes are up to date. Also, add a helper function to send a message to the update channel of a machine. --- machine.go | 81 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/machine.go b/machine.go index 69de453d..14efd008 100644 --- a/machine.go +++ b/machine.go @@ -2,6 +2,7 @@ package headscale import ( "encoding/json" + "errors" "fmt" "sort" "strconv" @@ -31,8 +32,9 @@ type Machine struct { AuthKeyID uint AuthKey *PreAuthKey - LastSeen *time.Time - Expiry *time.Time + LastSeen *time.Time + LastSuccessfulUpdate *time.Time + Expiry *time.Time HostInfo datatypes.JSON Endpoints datatypes.JSON @@ -211,6 +213,13 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { return &m, nil } +func (h *Headscale) UpdateMachine(m *Machine) error { + if result := h.db.Find(m).First(&m); result.Error != nil { + return result.Error + } + return nil +} + // DeleteMachine softs deletes a Machine from the database func (h *Headscale) DeleteMachine(m *Machine) error { m.Registered = false @@ -251,21 +260,67 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { func (h *Headscale) notifyChangesToPeers(m *Machine) { peers, _ := h.getPeers(*m) for _, p := range *peers { - pUp, ok := h.clientsPolling.Load(uint64(p.ID)) - if ok { + log.Info(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", p.Name). + Str("address", p.Addresses[0].String()). + Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) + err := h.requestUpdate(p) + if err != nil { log.Info(). Str("func", "notifyChangesToPeers"). Str("machine", m.Name). - Str("peer", m.Name). - Str("address", p.Addresses[0].String()). - Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - pUp.(chan []byte) <- []byte{} - } else { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", m.Name). 
+ Str("peer", p.Name). Msgf("Peer %s does not appear to be polling", p.Name) } + log.Trace(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", p.Name). + Str("address", p.Addresses[0].String()). + Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0]) } } + +func (h *Headscale) requestUpdate(m *tailcfg.Node) error { + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) + if ok { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notifying peer %s", m.Name) + + if update, ok := pUp.(chan struct{}); ok { + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Update channel is %#v", update) + + update <- struct{}{} + + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notified machine %s", m.Name) + } + } else { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Machine %s does not appear to be polling", m.Name) + return errors.New("machine does not seem to be polling") + } + return nil +} + +func (h *Headscale) isOutdated(m *Machine) bool { + lastChange := h.getLastStateChange() + log.Trace(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", lastChange). + Msgf("Checking if %s is missing updates", m.Name) + return m.LastSuccessfulUpdate.Before(lastChange) +} From 57b79aa852973bbc9ac3ab6a952061b83c1d008b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Aug 2021 23:21:11 +0100 Subject: [PATCH 05/18] Set timeout, add lastupdate field This commit makes two reasonably major changes: The first is setting a default timeout for the Go HTTP server (which gin uses), which allows us to actually have broken long poll sessions fail so we can have the client re-establish them. The current 10s number is chosen arbitrarily and we need more testing to ensure that the feature works as intended. 
The second is adding a last updated field to keep track of the last time we had an update that needs to be propagated to all of our clients/nodes. This will be used to keep track of our machines and if they are up to date or need us to push an update. --- app.go | 52 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/app.go b/app.go index fcf287f7..255a7df1 100644 --- a/app.go +++ b/app.go @@ -58,7 +58,10 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsPolling sync.Map + clientsUpdateChannels sync.Map + + lastStateChangeMutex sync.RWMutex + lastStateChange time.Time } // NewHeadscale returns the Headscale app @@ -85,12 +88,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) { } h := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - publicKey: &pubKey, - aclRules: &tailcfg.FilterAllowAll, // default allowall + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + publicKey: &pubKey, + aclRules: &tailcfg.FilterAllowAll, // default allowall + lastStateChange: time.Now().UTC(), } err = h.initDB() @@ -168,6 +172,13 @@ func (h *Headscale) Serve() error { go h.watchForKVUpdates(5000) go h.expireEphemeralNodes(5000) + s := &http.Server{ + Addr: h.cfg.Addr, + Handler: r, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + if h.cfg.TLSLetsEncryptHostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") @@ -179,9 +190,11 @@ func (h *Headscale) Serve() error { Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), } s := &http.Server{ - Addr: h.cfg.Addr, - TLSConfig: m.TLSConfig(), - Handler: r, + Addr: h.cfg.Addr, + TLSConfig: m.TLSConfig(), + Handler: r, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, } if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { // 
Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) @@ -206,12 +219,27 @@ func (h *Headscale) Serve() error { if !strings.HasPrefix(h.cfg.ServerURL, "http://") { log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") } - err = r.Run(h.cfg.Addr) + err = s.ListenAndServe() } else { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") } - err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath) + err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath) } return err } + +func (h *Headscale) setLastStateChangeToNow() { + h.lastStateChangeMutex.Lock() + + now := time.Now().UTC() + h.lastStateChange = now + + h.lastStateChangeMutex.Unlock() +} + +func (h *Headscale) getLastStateChange() time.Time { + h.lastStateChangeMutex.RLock() + defer h.lastStateChangeMutex.RUnlock() + return h.lastStateChange +} From dd8c0d1e9e2415247fc37414ed468ec25e8c5f37 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Aug 2021 23:24:22 +0100 Subject: [PATCH 06/18] Move most "poll" functionality to poll.go This commit migrates more poll functions (including keepalive) to poll.go to keep them together in the same file. In addition it makes changes to improve the stability and ensure nodes get the appropriate updates from the headscale control and are not left in an inconsistent state. Two new additions are: omitpeers=true will now trigger an update if the clients are not already up to date; keepalive has been extended with a timer that will check every 30s if all nodes are up to date. 
--- api.go | 203 ++---------------------------- poll.go | 385 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 365 insertions(+), 223 deletions(-) diff --git a/api.go b/api.go index 0dc2bec5..7a6b4b16 100644 --- a/api.go +++ b/api.go @@ -13,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" - "gorm.io/datatypes" "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/tailcfg" @@ -82,14 +81,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } + now := time.Now().UTC() var m Machine if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") m = Machine{ - Expiry: &req.Expiry, - MachineKey: mKey.HexString(), - Name: req.Hostinfo.Hostname, - NodeKey: wgkey.Key(req.NodeKey).HexString(), + Expiry: &req.Expiry, + MachineKey: mKey.HexString(), + Name: req.Hostinfo.Hostname, + NodeKey: wgkey.Key(req.NodeKey).HexString(), + LastSuccessfulUpdate: &now, } if err := h.db.Create(&m).Error; err != nil { log.Error(). @@ -215,196 +216,6 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { c.Data(200, "application/json; charset=utf-8", respBody) } -// PollNetMapHandler takes care of /machine/:id/map -// -// This is the busiest endpoint, as it keeps the HTTP long poll that updates -// the clients when something in the network changes. -// -// The clients POST stuff like HostInfo and their Endpoints here, but -// only after their first request (marked with the ReadOnly field). -// -// At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (h *Headscale) PollNetMapHandler(c *gin.Context) { - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). 
- Msg("PollNetMapHandler called") - body, _ := io.ReadAll(c.Request.Body) - mKeyStr := c.Param("id") - mKey, err := wgkey.ParseHex(mKeyStr) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Err(err). - Msg("Cannot parse client key") - c.String(http.StatusBadRequest, "") - return - } - req := tailcfg.MapRequest{} - err = decode(body, &req, &mKey, h.privateKey) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Err(err). - Msg("Cannot decode message") - c.String(http.StatusBadRequest, "") - return - } - - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) - c.String(http.StatusUnauthorized, "") - return - } - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Found machine in database") - - hostinfo, _ := json.Marshal(req.Hostinfo) - m.Name = req.Hostinfo.Hostname - m.HostInfo = datatypes.JSON(hostinfo) - m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() - now := time.Now().UTC() - - // From Tailscale client: - // - // ReadOnly is whether the client just wants to fetch the MapResponse, - // without updating their Endpoints. The Endpoints field will be ignored and - // LastSeen will not be updated and peers will not be notified of changes. - // - // The intended use is for clients to discover the DERP map at start-up - // before their first real endpoint update. - if !req.ReadOnly { - endpoints, _ := json.Marshal(req.Endpoints) - m.Endpoints = datatypes.JSON(endpoints) - m.LastSeen = &now - } - h.db.Save(&m) - - data, err := h.getMapResponse(mKey, req, m) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Err(err). 
- Msg("Failed to get Map response") - c.String(http.StatusInternalServerError, ":(") - return - } - - // We update our peers if the client is not sending ReadOnly in the MapRequest - // so we don't distribute its initial request (it comes with - // empty endpoints to peers) - - // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 - log.Debug(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Bool("readOnly", req.ReadOnly). - Bool("omitPeers", req.OmitPeers). - Bool("stream", req.Stream). - Msg("Client map request processed") - - if req.ReadOnly { - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client is starting up. Asking for DERP map") - c.Data(200, "application/json; charset=utf-8", *data) - return - } - if req.OmitPeers && !req.Stream { - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client sent endpoint update and is ok with a response without peer list") - c.Data(200, "application/json; charset=utf-8", *data) - return - } else if req.OmitPeers && req.Stream { - log.Warn(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Ignoring request, don't know how to handle it") - c.String(http.StatusBadRequest, "") - return - } - - // Only create update channel if it has not been created - var update chan []byte - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Creating or loading update channel") - if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok { - update = result.(chan []byte) - } - - pollData := make(chan []byte, 1) - defer close(pollData) - - cancelKeepAlive := make(chan []byte, 1) - defer close(cancelKeepAlive) - - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client is ready to access the tailnet") - log.Info(). - Str("handler", "PollNetMap"). 
- Str("machine", m.Name). - Msg("Sending initial map") - pollData <- *data - - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Notifying peers") - // TODO: Why does this block? - go h.notifyChangesToPeers(&m) - - h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive) - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Finished stream, closing PollNetMap session") -} - -func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) { - for { - select { - case <-cancel: - return - - default: - data, err := h.getMapKeepAliveResponse(mKey, req, m) - if err != nil { - log.Error(). - Str("func", "keepAlive"). - Err(err). - Msg("Error generating the keep alive msg") - return - } - - log.Debug(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Msg("Sending keepalive") - pollData <- *data - - time.Sleep(60 * time.Second) - } - } -} - func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { log.Trace(). Str("func", "getMapResponse"). @@ -542,7 +353,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, Str("func", "handleAuthKey"). Str("machine", m.Name). Str("ip", ip.String()). - Msgf("Assining %s to %s", ip, m.Name) + Msgf("Assigning %s to %s", ip, m.Name) m.AuthKeyID = uint(pak.ID) m.IPAddress = ip.String() diff --git a/poll.go b/poll.go index f0bfe706..d58d45f1 100644 --- a/poll.go +++ b/poll.go @@ -1,38 +1,242 @@ package headscale import ( + "encoding/json" + "errors" "io" + "net/http" "time" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "gorm.io/datatypes" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) +// PollNetMapHandler takes care of /machine/:id/map +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. 
+// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. +func (h *Headscale) PollNetMapHandler(c *gin.Context) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Msg("PollNetMapHandler called") + body, _ := io.ReadAll(c.Request.Body) + mKeyStr := c.Param("id") + mKey, err := wgkey.ParseHex(mKeyStr) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot parse client key") + c.String(http.StatusBadRequest, "") + return + } + req := tailcfg.MapRequest{} + err = decode(body, &req, &mKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot decode message") + c.String(http.StatusBadRequest, "") + return + } + + var m Machine + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.Warn(). + Str("handler", "PollNetMap"). + Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) + c.String(http.StatusUnauthorized, "") + return + } + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Found machine in database") + + hostinfo, _ := json.Marshal(req.Hostinfo) + m.Name = req.Hostinfo.Hostname + m.HostInfo = datatypes.JSON(hostinfo) + m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() + now := time.Now().UTC() + + // From Tailscale client: + // + // ReadOnly is whether the client just wants to fetch the MapResponse, + // without updating their Endpoints. The Endpoints field will be ignored and + // LastSeen will not be updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at start-up + // before their first real endpoint update. 
+ if !req.ReadOnly { + endpoints, _ := json.Marshal(req.Endpoints) + m.Endpoints = datatypes.JSON(endpoints) + m.LastSeen = &now + } + h.db.Save(&m) + + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Err(err). + Msg("Failed to get Map response") + c.String(http.StatusInternalServerError, ":(") + return + } + + // We update our peers if the client is not sending ReadOnly in the MapRequest + // so we don't distribute its initial request (it comes with + // empty endpoints to peers) + + // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 + log.Debug(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Bool("readOnly", req.ReadOnly). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). + Msg("Client map request processed") + + if req.ReadOnly { + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client is starting up. Probably interested in a DERP map") + c.Data(200, "application/json; charset=utf-8", *data) + return + } + + // There has been an update to _any_ of the nodes that the other nodes would + // need to know about + h.setLastStateChangeToNow() + + // The request is not ReadOnly, so we need to set up channels for updating + // peers via longpoll + + // Only create update channel if it has not been created + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Loading or creating update channel") + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if wrapped, ok := storedChan.(chan struct{}); ok { + updateChan = wrapped + } else { + log.Error(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). 
+ Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + + pollDataChan := make(chan []byte) + // defer close(pollData) + + keepAliveChan := make(chan []byte) + + cancelKeepAlive := make(chan struct{}) + defer close(cancelKeepAlive) + + if req.OmitPeers && !req.Stream { + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client sent endpoint update and is ok with a response without peer list") + c.Data(200, "application/json; charset=utf-8", *data) + + // It sounds like we should update the nodes when we have received a endpoint update + // even tho the comments in the tailscale code dont explicitly say so. + go h.notifyChangesToPeers(&m) + return + } else if req.OmitPeers && req.Stream { + log.Warn(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Ignoring request, don't know how to handle it") + c.String(http.StatusBadRequest, "") + return + } + + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client is ready to access the tailnet") + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Sending initial map") + go func() { pollDataChan <- *data }() + + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Notifying peers") + go h.notifyChangesToPeers(&m) + + h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive) + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). 
+ Msg("Finished stream, closing PollNetMap session") +} + func (h *Headscale) PollNetMapStream( c *gin.Context, m Machine, req tailcfg.MapRequest, mKey wgkey.Key, - pollData chan []byte, - update chan []byte, - cancelKeepAlive chan []byte, + pollDataChan chan []byte, + keepAliveChan chan []byte, + updateChan chan struct{}, + cancelKeepAlive chan struct{}, ) { - - go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m) + go h.keepAlive(cancelKeepAlive, keepAliveChan, mKey, req, m) c.Stream(func(w io.Writer) bool { log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). Msg("Waiting for data to stream...") - select { - case data := <-pollData: + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) + + select { + case data := <-pollDataChan: log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Int("bytes", len(data)). Msg("Sending data received via pollData channel") _, err := w.Write(data) @@ -40,44 +244,99 @@ func (h *Headscale) PollNetMapStream( log.Error(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Err(err). Msg("Cannot write data") } log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Int("bytes", len(data)). Msg("Data from pollData channel written successfully") now := time.Now().UTC() m.LastSeen = &now + m.LastSuccessfulUpdate = &now + h.db.Save(&m) + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "pollData"). + Int("bytes", len(data)). + Msg("Machine updated successfully after sending pollData") + return true + + case data := <-keepAliveChan: + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Int("bytes", len(data)). 
+ Msg("Sending keep alive message") + _, err := w.Write(data) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Err(err). + Msg("Cannot write keep alive message") + } + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Int("bytes", len(data)). + Msg("Keep alive sent successfully") + now := time.Now().UTC() + m.LastSeen = &now h.db.Save(&m) log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "keepAlive"). Int("bytes", len(data)). - Msg("Machine updated successfully after sending pollData") + Msg("Machine updated successfully after sending keep alive") return true - case <-update: - log.Debug(). - Str("handler", "PollNetMapStream"). - Str("machine", m.Name). - Msg("Received a request for update") - data, err := h.getMapResponse(mKey, req, m) - if err != nil { - log.Error(). + case <-updateChan: + if h.isOutdated(&m) { + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). - Err(err). - Msg("Could not get the map update") - } - _, err = w.Write(*data) - if err != nil { - log.Error(). + Str("channel", "update"). + Msg("Received a request for update") + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "update"). + Err(err). + Msg("Could not get the map update") + } + _, err = w.Write(*data) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "update"). + Err(err). + Msg("Could not write the map response") + } + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). - Err(err). - Msg("Could not write the map response") + Str("channel", "update"). 
+ Msg("Updated Map has been sent") + + // Keep track of the last successful update, + // we sometimes end in a state were the update + // is not picked up by a client and we use this + // to determine if we should "force" an update. + now := time.Now().UTC() + m.LastSuccessfulUpdate = &now + h.db.Save(&m) } return true @@ -89,10 +348,82 @@ func (h *Headscale) PollNetMapStream( now := time.Now().UTC() m.LastSeen = &now h.db.Save(&m) - cancelKeepAlive <- []byte{} - h.clientsPolling.Delete(m.ID) - close(update) + + cancelKeepAlive <- struct{}{} + + h.clientsUpdateChannels.Delete(m.ID) + // close(updateChan) + + close(pollDataChan) + + close(keepAliveChan) + return false } }) } + +// TODO: Rename this function to schedule ... +func (h *Headscale) keepAlive( + cancelChan <-chan struct{}, + keepAliveChan chan<- []byte, + mKey wgkey.Key, + req tailcfg.MapRequest, + m Machine, +) { + keepAliveTicker := time.NewTicker(60 * time.Second) + updateCheckerTicker := time.NewTicker(30 * time.Second) + + for { + select { + case <-cancelChan: + return + + case <-keepAliveTicker.C: + data, err := h.getMapKeepAliveResponse(mKey, req, m) + if err != nil { + log.Error(). + Str("func", "keepAlive"). + Err(err). + Msg("Error generating the keep alive msg") + return + } + + log.Debug(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Msg("Sending keepalive") + keepAliveChan <- *data + + case <-updateCheckerTicker.C: + err := h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Err(err). + Msg("Could not refresh machine details from database") + return + } + if h.isOutdated(&m) { + log.Debug(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange()). 
+ Msgf("There has been updates since the last successful update to %s", m.Name) + + // TODO Error checking + n, _ := m.toNode() + h.requestUpdate(n) + } else { + log.Trace(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange()). + Msgf("%s is up to date", m.Name) + } + } + } +} From 8d1adaaef3abac891f6794c63268b5ac47c747e2 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 19 Aug 2021 18:05:33 +0100 Subject: [PATCH 07/18] Move isOutdated logic to updateChan consumation --- machine.go | 5 +++++ poll.go | 46 +++++++++++++++++++++------------------------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/machine.go b/machine.go index 14efd008..13e3529a 100644 --- a/machine.go +++ b/machine.go @@ -315,6 +315,11 @@ func (h *Headscale) requestUpdate(m *tailcfg.Node) error { } func (h *Headscale) isOutdated(m *Machine) bool { + err := h.UpdateMachine(m) + if err != nil { + return true + } + lastChange := h.getLastStateChange() log.Trace(). Str("func", "keepAlive"). diff --git a/poll.go b/poll.go index d58d45f1..27358fc5 100644 --- a/poll.go +++ b/poll.go @@ -300,12 +300,18 @@ func (h *Headscale) PollNetMapStream( return true case <-updateChan: + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "update"). + Msg("Received a request for update") if h.isOutdated(&m) { - log.Trace(). + log.Debug(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). - Str("channel", "update"). - Msg("Received a request for update") + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange()). + Msgf("There has been updates since the last successful update to %s", m.Name) data, err := h.getMapResponse(mKey, req, m) if err != nil { log.Error(). 
@@ -337,6 +343,13 @@ func (h *Headscale) PollNetMapStream( now := time.Now().UTC() m.LastSuccessfulUpdate = &now h.db.Save(&m) + } else { + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange()). + Msgf("%s is up to date", m.Name) } return true @@ -396,33 +409,16 @@ func (h *Headscale) keepAlive( keepAliveChan <- *data case <-updateCheckerTicker.C: - err := h.UpdateMachine(&m) + // Send an update request regardless of outdated or not, if data is sent + // to the node is determined in the updateChan consumer block + n, _ := m.toNode() + err := h.requestUpdate(n) if err != nil { log.Error(). Str("func", "keepAlive"). Str("machine", m.Name). Err(err). - Msg("Could not refresh machine details from database") - return - } - if h.isOutdated(&m) { - log.Debug(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange()). - Msgf("There has been updates since the last successful update to %s", m.Name) - - // TODO Error checking - n, _ := m.toNode() - h.requestUpdate(n) - } else { - log.Trace(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange()). 
- Msgf("%s is up to date", m.Name) + Msgf("Failed to send update request to %s", m.Name) } } } From 48ef6e5a6f492d80aae8e7467498b4053f11da58 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 19 Aug 2021 18:06:57 +0100 Subject: [PATCH 08/18] Rename keepAlive function, as it now does more things --- poll.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/poll.go b/poll.go index 27358fc5..522c529e 100644 --- a/poll.go +++ b/poll.go @@ -218,7 +218,7 @@ func (h *Headscale) PollNetMapStream( updateChan chan struct{}, cancelKeepAlive chan struct{}, ) { - go h.keepAlive(cancelKeepAlive, keepAliveChan, mKey, req, m) + go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) c.Stream(func(w io.Writer) bool { log.Trace(). @@ -376,8 +376,7 @@ func (h *Headscale) PollNetMapStream( }) } -// TODO: Rename this function to schedule ... -func (h *Headscale) keepAlive( +func (h *Headscale) scheduledPollWorker( cancelChan <-chan struct{}, keepAliveChan chan<- []byte, mKey wgkey.Key, From b0ec945dbb59196b7542386a2a543ddf5ba987b9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 19 Aug 2021 18:19:26 +0100 Subject: [PATCH 09/18] Make lastStateChange namespaced --- app.go | 40 ++++++++++++++++++++-------------------- machine.go | 2 +- poll.go | 6 +++--- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/app.go b/app.go index 255a7df1..76cf92ea 100644 --- a/app.go +++ b/app.go @@ -60,8 +60,7 @@ type Headscale struct { clientsUpdateChannels sync.Map - lastStateChangeMutex sync.RWMutex - lastStateChange time.Time + lastStateChange sync.Map } // NewHeadscale returns the Headscale app @@ -88,13 +87,12 @@ func NewHeadscale(cfg Config) (*Headscale, error) { } h := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - publicKey: &pubKey, - aclRules: &tailcfg.FilterAllowAll, // default allowall - lastStateChange: time.Now().UTC(), + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + 
privateKey: privKey, + publicKey: &pubKey, + aclRules: &tailcfg.FilterAllowAll, // default allowall } err = h.initDB() @@ -229,17 +227,19 @@ func (h *Headscale) Serve() error { return err } -func (h *Headscale) setLastStateChangeToNow() { - h.lastStateChangeMutex.Lock() +func (h *Headscale) setLastStateChangeToNow(namespace string) { + now := time.Now().UTC() + h.lastStateChange.Store(namespace, now) +} + +func (h *Headscale) getLastStateChange(namespace string) time.Time { + if wrapped, ok := h.lastStateChange.Load(namespace); ok { + lastChange, _ := wrapped.(time.Time) + return lastChange + + } now := time.Now().UTC() - h.lastStateChange = now - - h.lastStateChangeMutex.Unlock() -} - -func (h *Headscale) getLastStateChange() time.Time { - h.lastStateChangeMutex.RLock() - defer h.lastStateChangeMutex.RUnlock() - return h.lastStateChange + h.lastStateChange.Store(namespace, now) + return now } diff --git a/machine.go b/machine.go index 13e3529a..5352f741 100644 --- a/machine.go +++ b/machine.go @@ -320,7 +320,7 @@ func (h *Headscale) isOutdated(m *Machine) bool { return true } - lastChange := h.getLastStateChange() + lastChange := h.getLastStateChange(m.Namespace.Name) log.Trace(). Str("func", "keepAlive"). Str("machine", m.Name). diff --git a/poll.go b/poll.go index 522c529e..e85c7a9f 100644 --- a/poll.go +++ b/poll.go @@ -123,7 +123,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // There has been an update to _any_ of the nodes that the other nodes would // need to know about - h.setLastStateChangeToNow() + h.setLastStateChangeToNow(m.Namespace.Name) // The request is not ReadOnly, so we need to set up channels for updating // peers via longpoll @@ -310,7 +310,7 @@ func (h *Headscale) PollNetMapStream( Str("handler", "PollNetMapStream"). Str("machine", m.Name). Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange()). + Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). 
Msgf("There has been updates since the last successful update to %s", m.Name) data, err := h.getMapResponse(mKey, req, m) if err != nil { @@ -348,7 +348,7 @@ func (h *Headscale) PollNetMapStream( Str("handler", "PollNetMapStream"). Str("machine", m.Name). Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange()). + Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Msgf("%s is up to date", m.Name) } return true From 53168d54d803ef0c4a182ed6457a95e5681657e7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 19 Aug 2021 22:29:03 +0100 Subject: [PATCH 10/18] Make http timeout 30s instead of 10s --- app.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/app.go b/app.go index 76cf92ea..fe1b954b 100644 --- a/app.go +++ b/app.go @@ -167,14 +167,16 @@ func (h *Headscale) Serve() error { r.POST("/machine/:id", h.RegistrationHandler) var err error + timeout := 30 * time.Second + go h.watchForKVUpdates(5000) go h.expireEphemeralNodes(5000) s := &http.Server{ Addr: h.cfg.Addr, Handler: r, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + ReadTimeout: timeout, + WriteTimeout: timeout, } if h.cfg.TLSLetsEncryptHostname != "" { @@ -191,8 +193,8 @@ func (h *Headscale) Serve() error { Addr: h.cfg.Addr, TLSConfig: m.TLSConfig(), Handler: r, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + ReadTimeout: timeout, + WriteTimeout: timeout, } if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) From 1f422af1c80e7226b8555a660338bbd30b4ef36c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 20 Aug 2021 16:50:55 +0100 Subject: [PATCH 11/18] Save headscale logs if jobs fail --- .gitignore | 2 + integration_test.go | 91 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 
3a64648f..44bec691 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ config.json *.key /db.sqlite *.sqlite3 + +test_output/ diff --git a/integration_test.go b/integration_test.go index 892c7ece..fa0dfbf0 100644 --- a/integration_test.go +++ b/integration_test.go @@ -4,10 +4,13 @@ package headscale import ( "bytes" + "context" "fmt" + "io/ioutil" "log" "net/http" "os" + "path" "strings" "testing" "time" @@ -22,10 +25,35 @@ import ( type IntegrationTestSuite struct { suite.Suite + stats *suite.SuiteInformation } func TestIntegrationTestSuite(t *testing.T) { - suite.Run(t, new(IntegrationTestSuite)) + s := new(IntegrationTestSuite) + suite.Run(t, s) + + // HandleStats, which allows us to check if we passed and save logs + // is called after TearDown, so we cannot tear down containers before + // we have potentially saved the logs. + for _, tailscale := range tailscales { + if err := pool.Purge(&tailscale); err != nil { + log.Printf("Could not purge resource: %s\n", err) + } + } + + if !s.stats.Passed() { + err := saveLog(&headscale, "test_output") + if err != nil { + log.Printf("Could not save log: %s\n", err) + } + } + if err := pool.Purge(&headscale); err != nil { + log.Printf("Could not purge resource: %s\n", err) + } + + if err := network.Close(); err != nil { + log.Printf("Could not close network: %s\n", err) + } } var integrationTmpDir string @@ -34,7 +62,7 @@ var ih Headscale var pool dockertest.Pool var network dockertest.Network var headscale dockertest.Resource -var tailscaleCount int = 20 +var tailscaleCount int = 25 var tailscales map[string]dockertest.Resource func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) { @@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) return stdout.String(), nil } +func saveLog(resource *dockertest.Resource, basePath string) error { + err := os.MkdirAll(basePath, os.ModePerm) + if err != nil { + return err + } + + var stdout bytes.Buffer + 
var stderr bytes.Buffer + + err = pool.Client.Logs( + docker.LogsOptions{ + Context: context.TODO(), + Container: resource.Container.ID, + OutputStream: &stdout, + ErrorStream: &stderr, + Tail: "all", + RawTerminal: false, + Stdout: true, + Stderr: true, + Follow: false, + Timestamps: false, + }, + ) + if err != nil { + return err + } + + fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) + + err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644) + if err != nil { + return err + } + + err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644) + if err != nil { + return err + } + + return nil +} + func dockerRestartPolicy(config *docker.HostConfig) { // set AutoRemove to true so that stopped container goes away by itself config.AutoRemove = true @@ -194,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() { // The nodes need a bit of time to get their updated maps from headscale // TODO: See if we can have a more deterministic wait here. 
- time.Sleep(120 * time.Second) + time.Sleep(60 * time.Second) } func (s *IntegrationTestSuite) TearDownSuite() { - for _, tailscale := range tailscales { - if err := pool.Purge(&tailscale); err != nil { - log.Printf("Could not purge resource: %s\n", err) - } - } +} - if err := pool.Purge(&headscale); err != nil { - log.Printf("Could not purge resource: %s\n", err) - } - - if err := network.Close(); err != nil { - log.Printf("Could not close network: %s\n", err) - } +func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) { + s.stats = stats } func (s *IntegrationTestSuite) TestListNodes() { From 88d7ac04bf7a78f378cd3015c1b1bb083ba54cb3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 20 Aug 2021 16:52:34 +0100 Subject: [PATCH 12/18] Account for race condition in deleting/closing update channel This commit tries to address the possible race condition that can happen if a client closes its connection after we have fetched it from the syncmap before sending the message. To try to avoid introducing new deadlock conditions, all messages sent to updateChannel have been moved into a function, which handles the locking (instead of calling it all over the place). The same lock is used around the delete/close function. --- app.go | 3 ++- machine.go | 42 ++++++++++++++++++++++++++++++++++++++++-- poll.go | 29 ++++------------------------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/app.go b/app.go index fe1b954b..e5f44103 100644 --- a/app.go +++ b/app.go @@ -58,7 +58,8 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsUpdateChannels sync.Map + clientsUpdateChannels sync.Map + clientsUpdateChannelMutex sync.Mutex lastStateChange sync.Map } diff --git a/machine.go b/machine.go index 5352f741..57c48ba8 100644 --- a/machine.go +++ b/machine.go @@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { Str("peer", p.Name). Str("address", p.Addresses[0].String()). 
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - err := h.requestUpdate(p) + err := h.sendRequestOnUpdateChannel(p) if err != nil { log.Info(). Str("func", "notifyChangesToPeers"). @@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { } } -func (h *Headscale) requestUpdate(m *tailcfg.Node) error { +func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + updateChan = unwrapped + } else { + log.Error(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + return updateChan +} + +func (h *Headscale) closeUpdateChannel(m *Machine) { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + close(unwrapped) + } + } + h.clientsUpdateChannels.Delete(m.ID) +} + +func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) if ok { log.Info(). diff --git a/poll.go b/poll.go index e85c7a9f..d086fc44 100644 --- a/poll.go +++ b/poll.go @@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("id", c.Param("id")). Str("machine", m.Name). Msg("Loading or creating update channel") - var updateChan chan struct{} - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if wrapped, ok := storedChan.(chan struct{}); ok { - updateChan = wrapped - } else { - log.Error(). 
- Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Failed to convert update channel to struct{}") - } - } else { - log.Debug(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Update channel not found, creating") - - updateChan = make(chan struct{}) - h.clientsUpdateChannels.Store(m.ID, updateChan) - } + updateChan := h.getOrOpenUpdateChannel(&m) pollDataChan := make(chan []byte) // defer close(pollData) @@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream( mKey wgkey.Key, pollDataChan chan []byte, keepAliveChan chan []byte, - updateChan chan struct{}, + updateChan <-chan struct{}, cancelKeepAlive chan struct{}, ) { go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) @@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream( cancelKeepAlive <- struct{}{} - h.clientsUpdateChannels.Delete(m.ID) - // close(updateChan) + h.closeUpdateChannel(&m) close(pollDataChan) @@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker( // Send an update request regardless of outdated or not, if data is sent // to the node is determined in the updateChan consumer block n, _ := m.toNode() - err := h.requestUpdate(n) + err := h.sendRequestOnUpdateChannel(n) if err != nil { log.Error(). Str("func", "keepAlive"). 
From d93a7f2e02a4994a7b7f2c96630c4bde9fd989b4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 20 Aug 2021 17:15:07 +0100 Subject: [PATCH 13/18] Make Info default log level --- cmd/headscale/cli/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 4ada6408..7e7e8f96 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -39,7 +39,7 @@ func LoadConfig(path string) error { viper.SetDefault("ip_prefix", "100.64.0.0/10") - viper.SetDefault("log_level", "debug") + viper.SetDefault("log_level", "info") err := viper.ReadInConfig() if err != nil { From c49fe26da7700592cbf7ef04f3ed1787cfeb08ef Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 21 Aug 2021 09:15:16 +0100 Subject: [PATCH 14/18] Code clean up, loglevel debug for integration tests --- integration_test.go | 18 +++++++++--------- integration_test/etc/config.json | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/integration_test.go b/integration_test.go index fa0dfbf0..98cb925a 100644 --- a/integration_test.go +++ b/integration_test.go @@ -23,6 +23,15 @@ import ( "inet.af/netaddr" ) +var integrationTmpDir string +var ih Headscale + +var pool dockertest.Pool +var network dockertest.Network +var headscale dockertest.Resource +var tailscaleCount int = 50 +var tailscales map[string]dockertest.Resource + type IntegrationTestSuite struct { suite.Suite stats *suite.SuiteInformation @@ -56,15 +65,6 @@ func TestIntegrationTestSuite(t *testing.T) { } } -var integrationTmpDir string -var ih Headscale - -var pool dockertest.Pool -var network dockertest.Network -var headscale dockertest.Resource -var tailscaleCount int = 25 -var tailscales map[string]dockertest.Resource - func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) { var stdout bytes.Buffer var stderr bytes.Buffer diff --git a/integration_test/etc/config.json b/integration_test/etc/config.json index 
5454f2f7..8a6fd962 100644 --- a/integration_test/etc/config.json +++ b/integration_test/etc/config.json @@ -7,5 +7,5 @@ "db_type": "sqlite3", "db_path": "/tmp/integration_test_db.sqlite3", "acl_policy_path": "", - "log_level": "trace" + "log_level": "debug" } From a054e2514ae363ea78052775b3fad44864f3a604 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 21 Aug 2021 09:26:18 +0100 Subject: [PATCH 15/18] Keep tailscale count at 25 in integration tests --- integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_test.go b/integration_test.go index 98cb925a..8cdc1918 100644 --- a/integration_test.go +++ b/integration_test.go @@ -29,7 +29,7 @@ var ih Headscale var pool dockertest.Pool var network dockertest.Network var headscale dockertest.Resource -var tailscaleCount int = 50 +var tailscaleCount int = 25 var tailscales map[string]dockertest.Resource type IntegrationTestSuite struct { From 0aeeaac3614737861c53ae2ef5a956736d94fbc6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 21 Aug 2021 16:52:19 +0100 Subject: [PATCH 16/18] Always load machine object from DB before save/modify We are currently holding Machine objects in memory for a long time, while waiting for stream/longpoll, this might make us end up with stale objects, that we just call save on, potentially overwriting stuff in the database. A typical scenario would be someone changing something from the CLI, e.g. enabling routes, which in turn is overwritten again by the stale object in the longpolling function. The code has been left with TODO's and a discussion is available in #93. --- poll.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/poll.go b/poll.go index d086fc44..fdf522cd 100644 --- a/poll.go +++ b/poll.go @@ -234,6 +234,18 @@ func (h *Headscale) PollNetMapStream( Str("channel", "pollData"). Int("bytes", len(data)). 
Msg("Data from pollData channel written successfully") + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err = h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "pollData"). + Err(err). + Msg("Cannot update machine from database") + } now := time.Now().UTC() m.LastSeen = &now m.LastSuccessfulUpdate = &now @@ -268,6 +280,18 @@ func (h *Headscale) PollNetMapStream( Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Keep alive sent successfully") + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err = h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Err(err). + Msg("Cannot update machine from database") + } now := time.Now().UTC() m.LastSeen = &now h.db.Save(&m) @@ -316,10 +340,22 @@ func (h *Headscale) PollNetMapStream( Str("channel", "update"). Msg("Updated Map has been sent") - // Keep track of the last successful update, - // we sometimes end in a state were the update - // is not picked up by a client and we use this - // to determine if we should "force" an update. + // Keep track of the last successful update, + // we sometimes end in a state were the update + // is not picked up by a client and we use this + // to determine if we should "force" an update. + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err = h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). 
+ Str("channel", "update"). + Err(err). + Msg("Cannot update machine from database") + } now := time.Now().UTC() m.LastSuccessfulUpdate = &now h.db.Save(&m) @@ -338,6 +374,18 @@ func (h *Headscale) PollNetMapStream( Str("handler", "PollNetMapStream"). Str("machine", m.Name). Msg("The client has closed the connection") + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err := h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "Done"). + Err(err). + Msg("Cannot update machine from database") + } now := time.Now().UTC() m.LastSeen = &now h.db.Save(&m) From ebd27b46afabb3f160151e61c96b868e5cff3b2a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 23 Aug 2021 07:35:44 +0100 Subject: [PATCH 17/18] Add comment to updatemachine --- machine.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/machine.go b/machine.go index 57c48ba8..4cdadd98 100644 --- a/machine.go +++ b/machine.go @@ -213,6 +213,8 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { return &m, nil } +// UpdateMachine takes a Machine struct pointer (typically already loaded from database) +// and updates it with the latest data from the database. 
func (h *Headscale) UpdateMachine(m *Machine) error { if result := h.db.Find(m).First(&m); result.Error != nil { return result.Error From 059f13fc9d8db9066c850433f5cb611746d17338 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 23 Aug 2021 07:38:14 +0100 Subject: [PATCH 18/18] Add missing comment for stream function --- poll.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/poll.go b/poll.go index fdf522cd..bea16164 100644 --- a/poll.go +++ b/poll.go @@ -188,6 +188,9 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Msg("Finished stream, closing PollNetMap session") } +// PollNetMapStream takes care of /machine/:id/map +// stream logic, ensuring we communicate updates and data +// to the connected clients. func (h *Headscale) PollNetMapStream( c *gin.Context, m Machine,