package notifier import ( "context" "fmt" "sort" "strings" "sync" "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/util/set" ) type Notifier struct { l sync.RWMutex nodes map[types.NodeID]chan<- types.StateUpdate connected *xsync.MapOf[types.NodeID, bool] b *batcher } func NewNotifier(cfg *types.Config) *Notifier { n := &Notifier{ nodes: make(map[types.NodeID]chan<- types.StateUpdate), connected: xsync.NewMapOf[types.NodeID, bool](), } b := newBatcher(cfg.Tuning.BatchChangeDelay, n) n.b = b // TODO(kradalby): clean this up go b.doWork() return n } func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) { log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node") defer log.Trace(). Caller(). Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to add node") start := time.Now() n.l.Lock() defer n.l.Unlock() notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds()) n.nodes[nodeID] = c n.connected.Store(nodeID, true) log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Added new channel") notifierNodeUpdateChans.Inc() } func (n *Notifier) RemoveNode(nodeID types.NodeID) { log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node") defer log.Trace(). Caller(). Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to remove node") start := time.Now() n.l.Lock() defer n.l.Unlock() notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds()) if len(n.nodes) == 0 { return } delete(n.nodes, nodeID) n.connected.Store(nodeID, false) log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Removed channel") notifierNodeUpdateChans.Dec() } // IsConnected reports if a node is connected to headscale and has a // poll session open. func (n *Notifier) IsConnected(nodeID types.NodeID) bool { n.l.RLock() defer n.l.RUnlock() if val, ok := n.connected.Load(nodeID); ok { return val } return false } // IsLikelyConnected reports if a node is connected to headscale and has a // poll session open, but doesnt lock, so might be wrong. func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { if val, ok := n.connected.Load(nodeID); ok { return val } return false } func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { return n.connected } func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { n.NotifyWithIgnore(ctx, update) } func (n *Notifier) NotifyWithIgnore( ctx context.Context, update types.StateUpdate, ignoreNodeIDs ...types.NodeID, ) { notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() n.b.addOrPassthrough(update) } func (n *Notifier) NotifyByNodeID( ctx context.Context, update types.StateUpdate, nodeID types.NodeID, ) { log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Str("type", update.Type.String()). Msg("releasing lock, finished notifying") start := time.Now() n.l.RLock() defer n.l.RUnlock() notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds()) if c, ok := n.nodes[nodeID]; ok { select { case <-ctx.Done(): log.Error(). Err(ctx.Err()). Uint64("node.id", nodeID.Uint64()). Any("origin", types.NotifyOriginKey.Value(ctx)). Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)). Msgf("update not sent, context cancelled") notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() return case c <- update: log.Trace(). Uint64("node.id", nodeID.Uint64()). Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update successfully sent on chan") notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() } } } func (n *Notifier) sendAll(update types.StateUpdate) { start := time.Now() n.l.RLock() defer n.l.RUnlock() notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds()) for _, c := range n.nodes { c <- update notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc() } } func (n *Notifier) String() string { n.l.RLock() defer n.l.RUnlock() var b strings.Builder b.WriteString("chans:\n") for k, v := range n.nodes { fmt.Fprintf(&b, "\t%d: %p\n", k, v) } b.WriteString("\n") b.WriteString("connected:\n") n.connected.Range(func(k types.NodeID, v bool) bool { fmt.Fprintf(&b, "\t%d: %t\n", k, v) return true }) return b.String() } type batcher struct { tick *time.Ticker mu sync.Mutex cancelCh chan struct{} changedNodeIDs set.Slice[types.NodeID] nodesChanged bool patches map[types.NodeID]tailcfg.PeerChange patchesChanged bool n *Notifier } func newBatcher(batchTime time.Duration, n *Notifier) *batcher { return &batcher{ tick: time.NewTicker(batchTime), cancelCh: make(chan struct{}), patches: make(map[types.NodeID]tailcfg.PeerChange), n: n, } } func (b *batcher) close() { b.cancelCh <- struct{}{} } // addOrPassthrough adds the update to the batcher, if it is not a // type that is currently batched, it will be sent immediately. func (b *batcher) addOrPassthrough(update types.StateUpdate) { b.mu.Lock() defer b.mu.Unlock() switch update.Type { case types.StatePeerChanged: b.changedNodeIDs.Add(update.ChangeNodes...) b.nodesChanged = true case types.StatePeerChangedPatch: for _, newPatch := range update.ChangePatches { if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok { overwritePatch(&curr, newPatch) b.patches[types.NodeID(newPatch.NodeID)] = curr } else { b.patches[types.NodeID(newPatch.NodeID)] = *newPatch } } b.patchesChanged = true default: b.n.sendAll(update) } } // flush sends all the accumulated patches to all // nodes in the notifier. func (b *batcher) flush() { b.mu.Lock() defer b.mu.Unlock() if b.nodesChanged || b.patchesChanged { var patches []*tailcfg.PeerChange // If a node is getting a full update from a change // node update, then the patch can be dropped. for nodeID, patch := range b.patches { if b.changedNodeIDs.Contains(nodeID) { delete(b.patches, nodeID) } else { patches = append(patches, &patch) } } changedNodes := b.changedNodeIDs.Slice().AsSlice() sort.Slice(changedNodes, func(i, j int) bool { return changedNodes[i] < changedNodes[j] }) if b.changedNodeIDs.Slice().Len() > 0 { update := types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: changedNodes, } b.n.sendAll(update) } if len(patches) > 0 { patchUpdate := types.StateUpdate{ Type: types.StatePeerChangedPatch, ChangePatches: patches, } b.n.sendAll(patchUpdate) } b.changedNodeIDs = set.Slice[types.NodeID]{} b.nodesChanged = false b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches)) b.patchesChanged = false } } func (b *batcher) doWork() { for { select { case <-b.cancelCh: return case <-b.tick.C: b.flush() } } } // overwritePatch takes the current patch and a newer patch // and override any field that has changed func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) { if newPatch.DERPRegion != 0 { currPatch.DERPRegion = newPatch.DERPRegion } if newPatch.Cap != 0 { currPatch.Cap = newPatch.Cap } if newPatch.CapMap != nil { currPatch.CapMap = newPatch.CapMap } if newPatch.Endpoints != nil { currPatch.Endpoints = newPatch.Endpoints } if newPatch.Key != nil { currPatch.Key = newPatch.Key } if newPatch.KeySignature != nil { currPatch.KeySignature = newPatch.KeySignature } if newPatch.DiscoKey != nil { currPatch.DiscoKey = newPatch.DiscoKey } if newPatch.Online != nil { currPatch.Online = newPatch.Online } if newPatch.LastSeen != nil { currPatch.LastSeen = newPatch.LastSeen } if newPatch.KeyExpiry != nil { currPatch.KeyExpiry = newPatch.KeyExpiry } if newPatch.Capabilities != nil { currPatch.Capabilities = newPatch.Capabilities } }