mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-30 02:43:05 +00:00
Account for racecondition in deleting/closing update channel
This commit tries to address the possible raceondition that can happen if a client closes its connection after we have fetched it from the syncmap before sending the message. To try to avoid introducing new dead lock conditions, all messages sent to updateChannel has been moved into a function, which handles the locking (instead of calling it all over the place) The same lock is used around the delete/close function.
This commit is contained in:
parent
1f422af1c8
commit
88d7ac04bf
3 changed files with 46 additions and 28 deletions
1
app.go
1
app.go
|
@ -59,6 +59,7 @@ type Headscale struct {
|
||||||
aclRules *[]tailcfg.FilterRule
|
aclRules *[]tailcfg.FilterRule
|
||||||
|
|
||||||
clientsUpdateChannels sync.Map
|
clientsUpdateChannels sync.Map
|
||||||
|
clientsUpdateChannelMutex sync.Mutex
|
||||||
|
|
||||||
lastStateChange sync.Map
|
lastStateChange sync.Map
|
||||||
}
|
}
|
||||||
|
|
42
machine.go
42
machine.go
|
@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
|
||||||
Str("peer", p.Name).
|
Str("peer", p.Name).
|
||||||
Str("address", p.Addresses[0].String()).
|
Str("address", p.Addresses[0].String()).
|
||||||
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
||||||
err := h.requestUpdate(p)
|
err := h.sendRequestOnUpdateChannel(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "notifyChangesToPeers").
|
Str("func", "notifyChangesToPeers").
|
||||||
|
@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) requestUpdate(m *tailcfg.Node) error {
|
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
|
||||||
|
var updateChan chan struct{}
|
||||||
|
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
|
||||||
|
if unwrapped, ok := storedChan.(chan struct{}); ok {
|
||||||
|
updateChan = unwrapped
|
||||||
|
} else {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "openUpdateChannel").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Failed to convert update channel to struct{}")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "openUpdateChannel").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Update channel not found, creating")
|
||||||
|
|
||||||
|
updateChan = make(chan struct{})
|
||||||
|
h.clientsUpdateChannels.Store(m.ID, updateChan)
|
||||||
|
}
|
||||||
|
return updateChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) closeUpdateChannel(m *Machine) {
|
||||||
|
h.clientsUpdateChannelMutex.Lock()
|
||||||
|
defer h.clientsUpdateChannelMutex.Unlock()
|
||||||
|
|
||||||
|
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
|
||||||
|
if unwrapped, ok := storedChan.(chan struct{}); ok {
|
||||||
|
close(unwrapped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.clientsUpdateChannels.Delete(m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
|
||||||
|
h.clientsUpdateChannelMutex.Lock()
|
||||||
|
defer h.clientsUpdateChannelMutex.Unlock()
|
||||||
|
|
||||||
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
|
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
|
||||||
if ok {
|
if ok {
|
||||||
log.Info().
|
log.Info().
|
||||||
|
|
29
poll.go
29
poll.go
|
@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
Str("id", c.Param("id")).
|
Str("id", c.Param("id")).
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Loading or creating update channel")
|
Msg("Loading or creating update channel")
|
||||||
var updateChan chan struct{}
|
updateChan := h.getOrOpenUpdateChannel(&m)
|
||||||
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
|
|
||||||
if wrapped, ok := storedChan.(chan struct{}); ok {
|
|
||||||
updateChan = wrapped
|
|
||||||
} else {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Failed to convert update channel to struct{}")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debug().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Update channel not found, creating")
|
|
||||||
|
|
||||||
updateChan = make(chan struct{})
|
|
||||||
h.clientsUpdateChannels.Store(m.ID, updateChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
pollDataChan := make(chan []byte)
|
pollDataChan := make(chan []byte)
|
||||||
// defer close(pollData)
|
// defer close(pollData)
|
||||||
|
@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
mKey wgkey.Key,
|
mKey wgkey.Key,
|
||||||
pollDataChan chan []byte,
|
pollDataChan chan []byte,
|
||||||
keepAliveChan chan []byte,
|
keepAliveChan chan []byte,
|
||||||
updateChan chan struct{},
|
updateChan <-chan struct{},
|
||||||
cancelKeepAlive chan struct{},
|
cancelKeepAlive chan struct{},
|
||||||
) {
|
) {
|
||||||
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
|
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
|
||||||
|
@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
|
|
||||||
cancelKeepAlive <- struct{}{}
|
cancelKeepAlive <- struct{}{}
|
||||||
|
|
||||||
h.clientsUpdateChannels.Delete(m.ID)
|
h.closeUpdateChannel(&m)
|
||||||
// close(updateChan)
|
|
||||||
|
|
||||||
close(pollDataChan)
|
close(pollDataChan)
|
||||||
|
|
||||||
|
@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
// Send an update request regardless of outdated or not, if data is sent
|
// Send an update request regardless of outdated or not, if data is sent
|
||||||
// to the node is determined in the updateChan consumer block
|
// to the node is determined in the updateChan consumer block
|
||||||
n, _ := m.toNode()
|
n, _ := m.toNode()
|
||||||
err := h.requestUpdate(n)
|
err := h.sendRequestOnUpdateChannel(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "keepAlive").
|
Str("func", "keepAlive").
|
||||||
|
|
Loading…
Reference in a new issue