package notifier import ( "context" "fmt" "slices" "strings" "sync" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" ) type Notifier struct { l sync.RWMutex nodes map[types.NodeID]chan<- types.StateUpdate connected types.NodeConnectedMap } func NewNotifier() *Notifier { return &Notifier{ nodes: make(map[types.NodeID]chan<- types.StateUpdate), connected: make(types.NodeConnectedMap), } } func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) { log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node") defer log.Trace(). Caller(). Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to add node") n.l.Lock() defer n.l.Unlock() n.nodes[nodeID] = c n.connected[nodeID] = true log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Added new channel") } func (n *Notifier) RemoveNode(nodeID types.NodeID) { log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node") defer log.Trace(). Caller(). Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to remove node") n.l.Lock() defer n.l.Unlock() if len(n.nodes) == 0 { return } delete(n.nodes, nodeID) n.connected[nodeID] = false log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Removed channel") } // IsConnected reports if a node is connected to headscale and has a // poll session open. func (n *Notifier) IsConnected(nodeID types.NodeID) bool { n.l.RLock() defer n.l.RUnlock() return n.connected[nodeID] } // IsLikelyConnected reports if a node is connected to headscale and has a // poll session open, but doesnt lock, so might be wrong. func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { return n.connected[nodeID] } // TODO(kradalby): This returns a pointer and can be dangerous. func (n *Notifier) ConnectedMap() types.NodeConnectedMap { return n.connected } func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { n.NotifyWithIgnore(ctx, update) } func (n *Notifier) NotifyWithIgnore( ctx context.Context, update types.StateUpdate, ignoreNodeIDs ...types.NodeID, ) { log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Str("type", update.Type.String()). Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() if update.Type == types.StatePeerChangedPatch { log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT") } for nodeID, c := range n.nodes { if slices.Contains(ignoreNodeIDs, nodeID) { continue } select { case <-ctx.Done(): log.Error(). Err(ctx.Err()). Uint64("node.id", nodeID.Uint64()). Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update not sent, context cancelled") return case c <- update: log.Trace(). Uint64("node.id", nodeID.Uint64()). Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update successfully sent on chan") } } } func (n *Notifier) NotifyByMachineKey( ctx context.Context, update types.StateUpdate, nodeID types.NodeID, ) { log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Str("type", update.Type.String()). Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() if c, ok := n.nodes[nodeID]; ok { select { case <-ctx.Done(): log.Error(). Err(ctx.Err()). Uint64("node.id", nodeID.Uint64()). Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update not sent, context cancelled") return case c <- update: log.Trace(). Uint64("node.id", nodeID.Uint64()). Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update successfully sent on chan") } } } func (n *Notifier) String() string { n.l.RLock() defer n.l.RUnlock() var b strings.Builder b.WriteString("chans:\n") for k, v := range n.nodes { fmt.Fprintf(&b, "\t%d: %p\n", k, v) } b.WriteString("\n") b.WriteString("connected:\n") for k, v := range n.connected { fmt.Fprintf(&b, "\t%d: %t\n", k, v) } return b.String() }