diff --git a/api.go b/api.go index 17306b32..ce3b242e 100644 --- a/api.go +++ b/api.go @@ -277,7 +277,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma log.Trace(). Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). - Interface("payload", resp). + // Interface("payload", resp). Msgf("Generated map response: %s", tailMapResponseToString(resp)) var respBody []byte diff --git a/app.go b/app.go index 27b9672d..acaa7f1b 100644 --- a/app.go +++ b/app.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "os" + "sort" "strings" "sync" "time" @@ -65,9 +66,6 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsUpdateChannels sync.Map - clientsUpdateChannelMutex sync.Mutex - lastStateChange sync.Map } @@ -145,10 +143,9 @@ func (h *Headscale) expireEphemeralNodesWorker() { if err != nil { log.Error().Err(err).Str("machine", m.Name).Msg("🤮 Cannot delete ephemeral machine from the database") } - updateRequestsFromNode.WithLabelValues("ephemeral-node-update").Inc() - h.notifyChangesToPeers(&m) } } + h.setLastStateChangeToNow(ns.Name) } } @@ -251,14 +248,28 @@ func (h *Headscale) setLastStateChangeToNow(namespace string) { h.lastStateChange.Store(namespace, now) } -func (h *Headscale) getLastStateChange(namespace string) time.Time { - if wrapped, ok := h.lastStateChange.Load(namespace); ok { - lastChange, _ := wrapped.(time.Time) - return lastChange +func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { + times := []time.Time{} + + for _, namespace := range namespaces { + if wrapped, ok := h.lastStateChange.Load(namespace); ok { + lastChange, _ := wrapped.(time.Time) + + times = append(times, lastChange) + } } - now := time.Now().UTC() - h.lastStateChange.Store(namespace, now) - return now + sort.Slice(times, func(i, j int) bool { + return times[i].After(times[j]) + }) + + log.Trace().Msgf("Latest times %#v", times) + + if len(times) == 0 { + return time.Now().UTC() + + } else { + return times[0] + } } diff --git a/go.mod b/go.mod index 70428315..1fadd6b7 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 + github.com/fatih/set v0.2.1 // indirect github.com/gin-gonic/gin v1.7.4 github.com/gofrs/uuid v4.0.0+incompatible github.com/google/go-github v17.0.0+incompatible // indirect diff --git a/go.sum b/go.sum index 6dc92b6d..b429ca95 100644 --- a/go.sum +++ b/go.sum @@ -204,6 +204,8 @@ github.com/fanliao/go-promise v0.0.0-20141029170127-1890db352a72/go.mod h1:Pjfxu github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= +github.com/fatih/set v0.2.1 h1:nn2CaJyknWE/6txyUDGwysr3G5QC6xWB/PtVjPBbeaA= +github.com/fatih/set v0.2.1/go.mod h1:+RKtMCH+favT2+3YecHGxcc0b4KyVWA1QWWJUs4E0CI= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= diff --git a/machine.go b/machine.go index 7a9df2e8..21d774b4 100644 --- a/machine.go +++ b/machine.go @@ -2,13 +2,13 @@ package headscale import ( "encoding/json" - "errors" "fmt" "sort" "strconv" "strings" "time" + "github.com/fatih/set" "github.com/rs/zerolog/log" "gorm.io/datatypes" @@ -66,7 +66,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") - return nil, err + return Machines{}, err } sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) @@ -88,7 +88,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", m.NamespaceID).Find(&sharedMachines).Error; err != nil { - return nil, err + return Machines{}, err } peers := make(Machines, 0) @@ -112,7 +112,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { Str("func", "getPeers"). Err(err). Msg("Cannot fetch peers") - return nil, err + return Machines{}, err } shared, err := h.getShared(m) @@ -121,7 +121,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { Str("func", "getDirectPeers"). Err(err). Msg("Cannot fetch peers") - return nil, err + return Machines{}, err } peers := append(direct, shared...) @@ -214,118 +214,31 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { return &hostinfo, nil } -func (h *Headscale) notifyChangesToPeers(m *Machine) { - peers, err := h.getPeers(m) - if err != nil { - log.Error(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Msgf("Error getting peers: %s", err) - return - } - for _, peer := range peers { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", peer.Name). - Str("address", peer.IPAddress). - Msgf("Notifying peer %s (%s)", peer.Name, peer.IPAddress) - err := h.sendRequestOnUpdateChannel(&peer) - if err != nil { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", peer.Name). - Msgf("Peer %s does not have an open update client, skipping.", peer.Name) - continue - } - log.Trace(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", peer.Name). - Str("address", peer.IPAddress). - Msgf("Notified peer %s (%s)", peer.Name, peer.IPAddress) - } -} - -func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { - var updateChan chan struct{} - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if unwrapped, ok := storedChan.(chan struct{}); ok { - updateChan = unwrapped - } else { - log.Error(). - Str("handler", "openUpdateChannel"). - Str("machine", m.Name). - Msg("Failed to convert update channel to struct{}") - } - } else { - log.Debug(). - Str("handler", "openUpdateChannel"). - Str("machine", m.Name). - Msg("Update channel not found, creating") - - updateChan = make(chan struct{}) - h.clientsUpdateChannels.Store(m.ID, updateChan) - } - return updateChan -} - -func (h *Headscale) closeUpdateChannel(m *Machine) { - h.clientsUpdateChannelMutex.Lock() - defer h.clientsUpdateChannelMutex.Unlock() - - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if unwrapped, ok := storedChan.(chan struct{}); ok { - close(unwrapped) - } - } - h.clientsUpdateChannels.Delete(m.ID) -} - -func (h *Headscale) sendRequestOnUpdateChannel(m *Machine) error { - h.clientsUpdateChannelMutex.Lock() - defer h.clientsUpdateChannelMutex.Unlock() - - pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) - if ok { - log.Info(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Notifying peer %s", m.Name) - - if update, ok := pUp.(chan struct{}); ok { - log.Trace(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Update channel is %#v", update) - - updateRequestsToNode.Inc() - update <- struct{}{} - - log.Trace(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Notified machine %s", m.Name) - } - } else { - err := errors.New("machine does not have an open update channel") - log.Info(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Machine %s does not have an open update channel", m.Name) - return err - } - return nil -} - func (h *Headscale) isOutdated(m *Machine) bool { err := h.UpdateMachine(m) if err != nil { + // It does not seem meaningful to propagate this error as the end result + // will have to be that the machine has to be considered outdated. return true } - lastChange := h.getLastStateChange(m.Namespace.Name) + sharedMachines, _ := h.getShared(m) + + namespaceSet := set.New(set.ThreadSafe) + namespaceSet.Add(m.Namespace.Name) + + // Check if any of our shared namespaces has updates that we have + // not propagated. + for _, sharedMachine := range sharedMachines { + namespaceSet.Add(sharedMachine.Namespace.Name) + } + + namespaces := make([]string, namespaceSet.Size()) + for index, namespace := range namespaceSet.List() { + namespaces[index] = namespace.(string) + } + + lastChange := h.getLastStateChange(namespaces...) log.Trace(). Str("func", "keepAlive"). Str("machine", m.Name). diff --git a/metrics.go b/metrics.go index d516fade..0d3dca34 100644 --- a/metrics.go +++ b/metrics.go @@ -26,16 +26,16 @@ var ( Namespace: prometheusNamespace, Name: "update_request_from_node_total", Help: "The number of updates requested by a node/update function", - }, []string{"state"}) - updateRequestsToNode = promauto.NewCounter(prometheus.CounterOpts{ + }, []string{"namespace", "machine", "state"}) + updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, - Name: "update_request_to_node_total", + Name: "update_request_sent_to_node_total", Help: "The number of calls/messages issued on a specific nodes update channel", - }) + }, []string{"namespace", "machine", "status"}) //TODO(kradalby): This is very debugging, we might want to remove it. updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, Name: "update_request_received_on_channel_total", Help: "The number of update requests received on an update channel", - }, []string{"machine"}) + }, []string{"namespace", "machine"}) ) diff --git a/namespaces.go b/namespaces.go index 75b6eab5..2bf62bb3 100644 --- a/namespaces.go +++ b/namespaces.go @@ -176,24 +176,17 @@ func (h *Headscale) checkForNamespacesPendingUpdates() { return } - names := []string{} - err = json.Unmarshal([]byte(v), &names) + namespaces := []string{} + err = json.Unmarshal([]byte(v), &namespaces) if err != nil { return } - for _, name := range names { + for _, namespace := range namespaces { log.Trace(). Str("func", "RequestMapUpdates"). - Str("machine", name). - Msg("Sending updates to nodes in namespace") - machines, err := h.ListMachinesInNamespace(name) - if err != nil { - continue - } - for _, m := range *machines { - updateRequestsFromNode.WithLabelValues("namespace-update").Inc() - h.notifyChangesToPeers(&m) - } + Str("machine", namespace). + Msg("Sending updates to nodes in namespacespace") + h.setLastStateChangeToNow(namespace) } newV, err := h.getValue("namespaces_pending_updates") if err != nil { diff --git a/poll.go b/poll.go index a33d341b..6a652805 100644 --- a/poll.go +++ b/poll.go @@ -140,10 +140,9 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("id", c.Param("id")). Str("machine", m.Name). Msg("Loading or creating update channel") - updateChan := h.getOrOpenUpdateChannel(m) + updateChan := make(chan struct{}) pollDataChan := make(chan []byte) - // defer close(pollData) keepAliveChan := make(chan []byte) @@ -159,8 +158,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // It sounds like we should update the nodes when we have received a endpoint update // even tho the comments in the tailscale code dont explicitly say so. - updateRequestsFromNode.WithLabelValues("endpoint-update").Inc() - go h.notifyChangesToPeers(m) + updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc() + go func() { updateChan <- struct{}{} }() return } else if req.OmitPeers && req.Stream { log.Warn(). @@ -185,8 +184,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("handler", "PollNetMap"). Str("machine", m.Name). Msg("Notifying peers") - updateRequestsFromNode.WithLabelValues("full-update").Inc() - go h.notifyChangesToPeers(m) + updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc() + go func() { updateChan <- struct{}{} }() h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive) log.Trace(). @@ -206,10 +205,10 @@ func (h *Headscale) PollNetMapStream( mKey wgkey.Key, pollDataChan chan []byte, keepAliveChan chan []byte, - updateChan <-chan struct{}, + updateChan chan struct{}, cancelKeepAlive chan struct{}, ) { - go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) + go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) c.Stream(func(w io.Writer) bool { log.Trace(). @@ -325,7 +324,7 @@ func (h *Headscale) PollNetMapStream( Str("machine", m.Name). Str("channel", "update"). Msg("Received a request for update") - updateRequestsReceivedOnChannel.WithLabelValues(m.Name).Inc() + updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc() if h.isOutdated(m) { log.Debug(). Str("handler", "PollNetMapStream"). @@ -350,6 +349,7 @@ func (h *Headscale) PollNetMapStream( Str("channel", "update"). Err(err). Msg("Could not write the map response") + updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc() return false } log.Trace(). @@ -357,14 +357,15 @@ func (h *Headscale) PollNetMapStream( Str("machine", m.Name). Str("channel", "update"). Msg("Updated Map has been sent") + updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc() - // Keep track of the last successful update, - // we sometimes end in a state were the update - // is not picked up by a client and we use this - // to determine if we should "force" an update. - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. + // Keep track of the last successful update, + // we sometimes end in a state were the update + // is not picked up by a client and we use this + // to determine if we should "force" an update. + // TODO(kradalby): Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. err = h.UpdateMachine(m) if err != nil { log.Error(). @@ -423,7 +424,8 @@ func (h *Headscale) PollNetMapStream( Str("machine", m.Name). Str("channel", "Done"). Msg("Closing update channel") - h.closeUpdateChannel(m) + //h.closeUpdateChannel(m) + close(updateChan) log.Trace(). Str("handler", "PollNetMapStream"). @@ -446,13 +448,14 @@ func (h *Headscale) PollNetMapStream( func (h *Headscale) scheduledPollWorker( cancelChan <-chan struct{}, + updateChan chan<- struct{}, keepAliveChan chan<- []byte, mKey wgkey.Key, req tailcfg.MapRequest, m *Machine, ) { keepAliveTicker := time.NewTicker(60 * time.Second) - updateCheckerTicker := time.NewTicker(30 * time.Second) + updateCheckerTicker := time.NewTicker(10 * time.Second) for { select { @@ -476,16 +479,12 @@ func (h *Headscale) scheduledPollWorker( keepAliveChan <- data case <-updateCheckerTicker.C: - // Send an update request regardless of outdated or not, if data is sent - // to the node is determined in the updateChan consumer block - err := h.sendRequestOnUpdateChannel(m) - if err != nil { - log.Error(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Err(err). - Msgf("Failed to send update request to %s", m.Name) - } + log.Debug(). + Str("func", "scheduledPollWorker"). + Str("machine", m.Name). + Msg("Sending update request") + updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc() + updateChan <- struct{}{} } } }