Account for racecondition in deleting/closing update channel

This commit tries to address the possible raceondition  that can happen
if a client closes its connection after we have fetched it from the
syncmap before sending the message.

To try to avoid introducing new dead lock conditions, all messages sent
to updateChannel has been moved into a function, which handles the
locking (instead of calling it all over the place)

The same lock is used around the delete/close function.
This commit is contained in:
Kristoffer Dalby 2021-08-20 16:52:34 +01:00
parent 1f422af1c8
commit 88d7ac04bf
No known key found for this signature in database
GPG key ID: 09F62DC067465735
3 changed files with 46 additions and 28 deletions

3
app.go
View file

@ -58,7 +58,8 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
clientsUpdateChannels sync.Map
clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map
}

View file

@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
err := h.requestUpdate(p)
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info().
Str("func", "notifyChangesToPeers").
@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
}
}
func (h *Headscale) requestUpdate(m *tailcfg.Node) error {
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}
func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().

29
poll.go
View file

@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if wrapped, ok := storedChan.(chan struct{}); ok {
updateChan = wrapped
} else {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
updateChan := h.getOrOpenUpdateChannel(&m)
pollDataChan := make(chan []byte)
// defer close(pollData)
@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream(
mKey wgkey.Key,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) {
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream(
cancelKeepAlive <- struct{}{}
h.clientsUpdateChannels.Delete(m.ID)
// close(updateChan)
h.closeUpdateChannel(&m)
close(pollDataChan)
@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker(
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode()
err := h.requestUpdate(n)
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").