mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-19 10:20:05 +09:00
Split up MapResponse
This commits extends the mapper with functions for creating "delta" MapResponses for different purposes (peer changed, peer removed, derp). This wires up the new state management with a new StateUpdate struct letting the poll worker know what kind of update to send to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
66ff1fcd40
commit
4b65cf48d0
8 changed files with 284 additions and 115 deletions
|
@ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
|||
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
h.nodeNotifier.NotifyAll()
|
||||
h.nodeNotifier.NotifyAll(types.StateUpdate{
|
||||
Type: types.StateDERPUpdated,
|
||||
DERPMap: *h.DERPMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -721,7 +724,9 @@ func (h *Headscale) Serve() error {
|
|||
Str("path", aclPath).
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
h.nodeNotifier.NotifyAll()
|
||||
h.nodeNotifier.NotifyAll(types.StateUpdate{
|
||||
Type: types.StateFullUpdate,
|
||||
})
|
||||
}
|
||||
|
||||
default:
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
|
@ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags(
|
|||
}
|
||||
machine.ForcedTags = newTags
|
||||
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
}, machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||
|
@ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
|||
now := time.Now()
|
||||
machine.Expiry = &now
|
||||
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
}, machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||
|
@ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
|
|||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
}, machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
|
@ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
|
|||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &expiry
|
||||
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
}, machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
|
@ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string)
|
|||
return false
|
||||
}
|
||||
|
||||
func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool {
|
||||
ret := make(map[tailcfg.NodeID]bool)
|
||||
|
||||
for _, peer := range peers {
|
||||
ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListOnlineMachines(
|
||||
machine *types.Machine,
|
||||
) (map[tailcfg.NodeID]bool, error) {
|
||||
peers, err := hsdb.ListPeers(machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return OnlineMachineMap(peers), nil
|
||||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
|
||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||
|
@ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
|||
}
|
||||
}
|
||||
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
}, machine.MachineKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
|||
return
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
expired := make([]tailcfg.NodeID, 0)
|
||||
for idx, machine := range machines {
|
||||
if machine.IsEphemeral() && machine.LastSeen != nil &&
|
||||
time.Now().
|
||||
After(machine.LastSeen.Add(inactivityThreshhold)) {
|
||||
expiredFound = true
|
||||
expired = append(expired, tailcfg.NodeID(machine.ID))
|
||||
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
@ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
|||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifier.NotifyAll()
|
||||
if len(expired) > 0 {
|
||||
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: expired,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
|||
return time.Unix(0, 0)
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
expired := make([]tailcfg.NodeID, 0)
|
||||
for index, machine := range machines {
|
||||
if machine.IsExpired() &&
|
||||
machine.Expiry.After(lastCheck) {
|
||||
expiredFound = true
|
||||
expired = append(expired, tailcfg.NodeID(machine.ID))
|
||||
|
||||
err := hsdb.ExpireMachine(&machines[index])
|
||||
if err != nil {
|
||||
|
@ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
|||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifier.NotifyAll()
|
||||
if len(expired) > 0 {
|
||||
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: expired,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||
log.Error().Err(err).Msg("error getting routes")
|
||||
}
|
||||
|
||||
routesChanged := false
|
||||
changedMachines := make([]uint64, 0)
|
||||
for pos, route := range routes {
|
||||
if route.IsExitRoute() {
|
||||
continue
|
||||
|
@ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||
return err
|
||||
}
|
||||
|
||||
routesChanged = true
|
||||
changedMachines = append(changedMachines, route.MachineID)
|
||||
|
||||
continue
|
||||
}
|
||||
|
@ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||
return err
|
||||
}
|
||||
|
||||
routesChanged = true
|
||||
changedMachines = append(changedMachines, route.MachineID)
|
||||
}
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
hsdb.notifier.NotifyAll()
|
||||
if len(changedMachines) > 0 {
|
||||
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: changedMachines,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -129,45 +130,35 @@ func fullMapResponse(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
Node: tailnode,
|
||||
|
||||
// TODO: Only send if updated
|
||||
DERPMap: derpMap,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Node: tailnode,
|
||||
Peers: tailPeers,
|
||||
|
||||
// TODO(kradalby): Implement:
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
||||
// PeersChanged
|
||||
// PeersRemoved
|
||||
// PeersChangedPatch
|
||||
// PeerSeenChange
|
||||
// OnlineChange
|
||||
DERPMap: derpMap,
|
||||
|
||||
// TODO: Only send if updated
|
||||
DNSConfig: dnsConfig,
|
||||
Domain: baseDomain,
|
||||
|
||||
// TODO: Only send if updated
|
||||
Domain: baseDomain,
|
||||
|
||||
// Do not instruct clients to collect services, we do not
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
CollectServices: "false",
|
||||
|
||||
// TODO: Only send if updated
|
||||
PacketFilter: policy.ReduceFilterRules(machine, rules),
|
||||
|
||||
UserProfiles: profiles,
|
||||
|
||||
// TODO: Only send if updated
|
||||
SSHPolicy: sshPolicy,
|
||||
|
||||
ControlTime: &now,
|
||||
ControlTime: &now,
|
||||
KeepAlive: false,
|
||||
OnlineChange: db.OnlineMachineMap(peers),
|
||||
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: !logtail,
|
||||
|
@ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
|||
}
|
||||
}
|
||||
|
||||
// CreateMapResponse returns a MapResponse for the given machine.
|
||||
func (m Mapper) CreateMapResponse(
|
||||
// FullMapResponse returns a MapResponse for the given machine.
|
||||
func (m Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
|
@ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse(
|
|||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) CreateKeepAliveResponse(
|
||||
func (m Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
) ([]byte, error) {
|
||||
keepAliveResponse := tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
derpMap tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp.DERPMap = &derpMap
|
||||
|
||||
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
machineKeys []uint64,
|
||||
pol *policy.ACLPolicy,
|
||||
) ([]byte, error) {
|
||||
var err error
|
||||
changed := make(types.Machines, len(machineKeys))
|
||||
lastSeen := make(map[tailcfg.NodeID]bool)
|
||||
for idx, machineKey := range machineKeys {
|
||||
peer, err := m.db.GetMachineByID(machineKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
changed[idx] = *peer
|
||||
|
||||
// We have just seen the node, let the peers update their list.
|
||||
lastSeen[tailcfg.NodeID(peer.ID)] = true
|
||||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(
|
||||
keepAliveResponse,
|
||||
key.MachinePublic{},
|
||||
mapRequest.Compress,
|
||||
)
|
||||
rules, _, err := policy.GenerateFilterAndSSHRules(
|
||||
pol,
|
||||
machine,
|
||||
changed,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter out peers that have expired.
|
||||
changed = lo.Filter(changed, func(item types.Machine, index int) bool {
|
||||
return !item.IsExpired()
|
||||
})
|
||||
|
||||
// If there are filter rules present, see if there are any machines that cannot
|
||||
// access eachother at all and remove them from the changed.
|
||||
if len(rules) > 0 {
|
||||
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp.PeersChanged = tailPeers
|
||||
resp.PeerSeenChange = lastSeen
|
||||
|
||||
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) PeerRemovedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
removed []tailcfg.NodeID,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp.PeersRemoved = removed
|
||||
|
||||
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m Mapper) marshalMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
machine *types.Machine,
|
||||
compression string,
|
||||
) ([]byte, error) {
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||
if err != nil {
|
||||
|
@ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
|
||||
}
|
||||
|
||||
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
|
||||
// If isNoise is set, then the JSON body will be returned
|
||||
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
|
||||
func MarshalResponse(
|
||||
resp interface{},
|
||||
isNoise bool,
|
||||
privateKey2019 *key.MachinePrivate,
|
||||
machineKey key.MachinePublic,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal response")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNoise && privateKey2019 != nil {
|
||||
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
func (m Mapper) marshalMapResponse(
|
||||
resp interface{},
|
||||
machineKey key.MachinePublic,
|
||||
compression string,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse(
|
|||
return data, nil
|
||||
}
|
||||
|
||||
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
|
||||
// If isNoise is set, then the JSON body will be returned
|
||||
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
|
||||
func MarshalResponse(
|
||||
resp interface{},
|
||||
isNoise bool,
|
||||
privateKey2019 *key.MachinePrivate,
|
||||
machineKey key.MachinePublic,
|
||||
) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal response")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNoise && privateKey2019 != nil {
|
||||
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
|
@ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{
|
|||
return encoder
|
||||
},
|
||||
}
|
||||
|
||||
func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
}
|
||||
|
||||
online, err := m.db.ListOnlineMachines(machine)
|
||||
if err == nil {
|
||||
resp.OnlineChange = online
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
|
|
@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
|
@ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
|
||||
PacketFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
|
|
|
@ -3,24 +3,25 @@ package notifier
|
|||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
l sync.RWMutex
|
||||
nodes map[string]chan<- struct{}
|
||||
nodes map[string]chan<- types.StateUpdate
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
|
||||
func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
||||
if n.nodes == nil {
|
||||
n.nodes = make(map[string]chan<- struct{})
|
||||
n.nodes = make(map[string]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
n.nodes[machineKey] = c
|
||||
|
@ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) {
|
|||
delete(n.nodes, machineKey)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll() {
|
||||
n.NotifyWithIgnore()
|
||||
func (n *Notifier) NotifyAll(update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
||||
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
||||
|
@ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
|||
continue
|
||||
}
|
||||
|
||||
c <- struct{}{}
|
||||
c <- update
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,7 +116,7 @@ func (h *Headscale) handlePoll(
|
|||
return
|
||||
}
|
||||
|
||||
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
logErr(err, "Failed to create MapResponse")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
@ -163,7 +163,12 @@ func (h *Headscale) handlePoll(
|
|||
Inc()
|
||||
|
||||
// Tell all the other nodes about the new endpoint, but dont update ourselves.
|
||||
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
|
||||
h.nodeNotifier.NotifyWithIgnore(
|
||||
types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: []uint64{machine.ID},
|
||||
},
|
||||
machine.MachineKey)
|
||||
|
||||
return
|
||||
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
||||
|
@ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream(
|
|||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
|
||||
const chanSize = 8
|
||||
updateChan := make(chan struct{}, chanSize)
|
||||
updateChan := make(chan types.StateUpdate, chanSize)
|
||||
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
@ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream(
|
|||
for {
|
||||
select {
|
||||
case <-keepAliveTicker.C:
|
||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||
data, err := mapp.KeepAliveResponse(mapRequest, machine)
|
||||
if err != nil {
|
||||
logErr(err, "Error generating the keep alive msg")
|
||||
|
||||
|
@ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream(
|
|||
return
|
||||
}
|
||||
|
||||
case <-updateChan:
|
||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
case update := <-updateChan:
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
switch update.Type {
|
||||
case types.StateFullUpdate:
|
||||
data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
case types.StatePeerChanged:
|
||||
data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy)
|
||||
case types.StatePeerRemoved:
|
||||
data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed)
|
||||
case types.StateDERPUpdated:
|
||||
data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logErr(err, "Could not get the map update")
|
||||
logErr(err, "Could not get the create map update")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream(
|
|||
}
|
||||
}
|
||||
|
||||
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
|
||||
func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", machine).
|
||||
|
|
|
@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) {
|
|||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type StateUpdateType int
|
||||
|
||||
const (
|
||||
StateFullUpdate StateUpdateType = iota
|
||||
StatePeerChanged
|
||||
StatePeerRemoved
|
||||
StateDERPUpdated
|
||||
)
|
||||
|
||||
// StateUpdate is an internal message containing information about
|
||||
// a state change that has happened to the network.
|
||||
type StateUpdate struct {
|
||||
// The type of update
|
||||
Type StateUpdateType
|
||||
|
||||
// Changed must be set when Type is StatePeerChanged and
|
||||
// contain the Machine IDs of machines that has changed.
|
||||
Changed []uint64
|
||||
|
||||
// Removed must be set when Type is StatePeerRemoved and
|
||||
// contain a list of the nodes that has been removed from
|
||||
// the network.
|
||||
Removed []tailcfg.NodeID
|
||||
|
||||
// DERPMap must be set when Type is StateDERPUpdated and
|
||||
// contain the new DERP Map.
|
||||
DERPMap tailcfg.DERPMap
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue