PollNetMapHandler: refactor with chan lifetimes in mind

* Resolves an issue where sometimes attempted sends on a closed channel
  happened by ensuring the channels remain open for the entire goroutine.
* May be of help with regards to issue #203
This commit is contained in:
Csaba Sarkadi 2021-12-31 20:51:20 +01:00
parent d35fb8bba0
commit a32175f791
2 changed files with 45 additions and 50 deletions

View file

@ -18,7 +18,7 @@ test:
@go test -coverprofile=coverage.out ./... @go test -coverprofile=coverage.out ./...
test_integration: test_integration:
go test -tags integration -timeout 30m ./... go test -tags integration -timeout 30m -count=1 ./...
test_integration_cli: test_integration_cli:
go test -tags integration -v integration_cli_test.go integration_common_test.go go test -tags integration -v integration_cli_test.go integration_common_test.go

93
poll.go
View file

@ -1,8 +1,10 @@
package headscale package headscale
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"time" "time"
@ -152,14 +154,33 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("id", ctx.Param("id")). Str("id", ctx.Param("id")).
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
updateChan := make(chan struct{})
pollDataChan := make(chan []byte) // TODO: could probably remove all that duplication once generics land.
closeChanWithLog := func(channel interface{}, name string) {
log.Trace().
Str("handler", "PollNetMap").
Str("machine", machine.Name).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))
switch c := channel.(type) {
case (chan struct{}):
close(c)
case (chan []byte):
close(c)
}
}
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
defer closeChanWithLog(updateChan, "updateChan")
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, "pollDataChan")
keepAliveChan := make(chan []byte) keepAliveChan := make(chan []byte)
defer closeChanWithLog(keepAliveChan, "keepAliveChan")
cancelKeepAlive := make(chan struct{})
defer close(cancelKeepAlive)
if req.OmitPeers && !req.Stream { if req.OmitPeers && !req.Stream {
log.Info(). log.Info().
@ -172,7 +193,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update"). updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
Inc() Inc()
go func() { updateChan <- struct{}{} }() updateChan <- struct{}{}
return return
} else if req.OmitPeers && req.Stream { } else if req.OmitPeers && req.Stream {
@ -193,7 +214,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Sending initial map") Msg("Sending initial map")
go func() { pollDataChan <- data }() pollDataChan <- data
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -201,7 +222,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Msg("Notifying peers") Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update"). updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
Inc() Inc()
go func() { updateChan <- struct{}{} }() updateChan <- struct{}{}
h.PollNetMapStream( h.PollNetMapStream(
ctx, ctx,
@ -211,7 +232,6 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
pollDataChan, pollDataChan,
keepAliveChan, keepAliveChan,
updateChan, updateChan,
cancelKeepAlive,
) )
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -231,16 +251,20 @@ func (h *Headscale) PollNetMapStream(
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
cancelKeepAlive chan struct{},
) { ) {
go h.scheduledPollWorker( {
cancelKeepAlive, ctx, cancel := context.WithCancel(ctx.Request.Context())
updateChan, defer cancel()
keepAliveChan,
machineKey, go h.scheduledPollWorker(
mapRequest, ctx,
machine, updateChan,
) keepAliveChan,
machineKey,
mapRequest,
machine,
)
}
ctx.Stream(func(writer io.Writer) bool { ctx.Stream(func(writer io.Writer) bool {
log.Trace(). log.Trace().
@ -455,42 +479,13 @@ func (h *Headscale) PollNetMapStream(
machine.LastSeen = &now machine.LastSeen = &now
h.db.Save(&machine) h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Cancelling keepAlive channel")
cancelKeepAlive <- struct{}{}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing update channel")
// h.closeUpdateChannel(m)
close(updateChan)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing pollData channel")
close(pollDataChan)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing keepAliveChan channel")
close(keepAliveChan)
return false return false
} }
}) })
} }
func (h *Headscale) scheduledPollWorker( func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{}, ctx context.Context,
updateChan chan<- struct{}, updateChan chan<- struct{},
keepAliveChan chan<- []byte, keepAliveChan chan<- []byte,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -502,7 +497,7 @@ func (h *Headscale) scheduledPollWorker(
for { for {
select { select {
case <-cancelChan: case <-ctx.Done():
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C: