Split up MapResponse

This commits extends the mapper with functions for creating "delta"
MapResponses for different purposes (peer changed, peer removed, derp).

This wires up the new state management with a new StateUpdate struct
letting the poll worker know what kind of update to send to the
connected nodes.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-29 11:20:22 +01:00 committed by Kristoffer Dalby
parent 66ff1fcd40
commit 4b65cf48d0
8 changed files with 284 additions and 115 deletions

View file

@ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region
}
h.nodeNotifier.NotifyAll()
h.nodeNotifier.NotifyAll(types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: *h.DERPMap,
})
}
}
}
@ -721,7 +724,9 @@ func (h *Headscale) Serve() error {
Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change")
h.nodeNotifier.NotifyAll()
h.nodeNotifier.NotifyAll(types.StateUpdate{
Type: types.StateFullUpdate,
})
}
default:

View file

@ -13,6 +13,7 @@ import (
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags(
}
machine.ForcedTags = newTags
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
@ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
now := time.Now()
machine.Expiry = &now
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err)
@ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
}
machine.GivenName = newName
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err)
@ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf(
@ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string)
return false
}
func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool {
ret := make(map[tailcfg.NodeID]bool)
for _, peer := range peers {
ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline()
}
return ret
}
func (hsdb *HSDatabase) ListOnlineMachines(
machine *types.Machine,
) (map[tailcfg.NodeID]bool, error) {
peers, err := hsdb.ListPeers(machine)
if err != nil {
return nil, err
}
return OnlineMachineMap(peers), nil
}
// enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs))
@ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
}
}
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
}, machine.MachineKey)
return nil
}
@ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
return
}
expiredFound := false
expired := make([]tailcfg.NodeID, 0)
for idx, machine := range machines {
if machine.IsEphemeral() && machine.LastSeen != nil &&
time.Now().
After(machine.LastSeen.Add(inactivityThreshhold)) {
expiredFound = true
expired = append(expired, tailcfg.NodeID(machine.ID))
log.Info().
Str("machine", machine.Hostname).
Msg("Ephemeral client removed from database")
@ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
}
}
if expiredFound {
hsdb.notifier.NotifyAll()
if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
}
}
}
@ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
return time.Unix(0, 0)
}
expiredFound := false
expired := make([]tailcfg.NodeID, 0)
for index, machine := range machines {
if machine.IsExpired() &&
machine.Expiry.After(lastCheck) {
expiredFound = true
expired = append(expired, tailcfg.NodeID(machine.ID))
err := hsdb.ExpireMachine(&machines[index])
if err != nil {
@ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
}
}
if expiredFound {
hsdb.notifier.NotifyAll()
if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
}
}

View file

@ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
log.Error().Err(err).Msg("error getting routes")
}
routesChanged := false
changedMachines := make([]uint64, 0)
for pos, route := range routes {
if route.IsExitRoute() {
continue
@ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
return err
}
routesChanged = true
changedMachines = append(changedMachines, route.MachineID)
continue
}
@ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
return err
}
routesChanged = true
changedMachines = append(changedMachines, route.MachineID)
}
}
if routesChanged {
hsdb.notifier.NotifyAll()
if len(changedMachines) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: changedMachines,
})
}
return nil

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/url"
"sort"
"strings"
"sync"
"time"
@ -129,45 +130,35 @@ func fullMapResponse(
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: tailnode,
// TODO: Only send if updated
DERPMap: derpMap,
// TODO: Only send if updated
Node: tailnode,
Peers: tailPeers,
// TODO(kradalby): Implement:
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
// PeersChanged
// PeersRemoved
// PeersChangedPatch
// PeerSeenChange
// OnlineChange
DERPMap: derpMap,
// TODO: Only send if updated
DNSConfig: dnsConfig,
Domain: baseDomain,
// TODO: Only send if updated
Domain: baseDomain,
// Do not instruct clients to collect services, we do not
// Do not instruct clients to collect services we do not
// support or do anything with them
CollectServices: "false",
// TODO: Only send if updated
PacketFilter: policy.ReduceFilterRules(machine, rules),
UserProfiles: profiles,
// TODO: Only send if updated
SSHPolicy: sshPolicy,
ControlTime: &now,
ControlTime: &now,
KeepAlive: false,
OnlineChange: db.OnlineMachineMap(peers),
Debug: &tailcfg.Debug{
DisableLogTail: !logtail,
@ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
}
}
// CreateMapResponse returns a MapResponse for the given machine.
func (m Mapper) CreateMapResponse(
// FullMapResponse returns a MapResponse for the given machine.
func (m Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
pol *policy.ACLPolicy,
@ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse(
}
if m.isNoise {
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
}
func (m Mapper) CreateKeepAliveResponse(
func (m Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{
KeepAlive: true,
resp := m.baseMapResponse(machine)
resp.KeepAlive = true
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
derpMap tailcfg.DERPMap,
) ([]byte, error) {
resp := m.baseMapResponse(machine)
resp.DERPMap = &derpMap
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
machineKeys []uint64,
pol *policy.ACLPolicy,
) ([]byte, error) {
var err error
changed := make(types.Machines, len(machineKeys))
lastSeen := make(map[tailcfg.NodeID]bool)
for idx, machineKey := range machineKeys {
peer, err := m.db.GetMachineByID(machineKey)
if err != nil {
return nil, err
}
changed[idx] = *peer
// We have just seen the node, let the peers update their list.
lastSeen[tailcfg.NodeID(peer.ID)] = true
}
if m.isNoise {
return m.marshalMapResponse(
keepAliveResponse,
key.MachinePublic{},
mapRequest.Compress,
)
rules, _, err := policy.GenerateFilterAndSSHRules(
pol,
machine,
changed,
)
if err != nil {
return nil, err
}
// Filter out peers that have expired.
changed = lo.Filter(changed, func(item types.Machine, index int) bool {
return !item.IsExpired()
})
// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the changed.
if len(rules) > 0 {
changed = policy.FilterMachinesByACL(machine, changed, rules)
}
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
resp := m.baseMapResponse(machine)
resp.PeersChanged = tailPeers
resp.PeerSeenChange = lastSeen
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
removed []tailcfg.NodeID,
) ([]byte, error) {
resp := m.baseMapResponse(machine)
resp.PeersRemoved = removed
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
}
func (m Mapper) marshalMapResponse(
resp *tailcfg.MapResponse,
machine *types.Machine,
compression string,
) ([]byte, error) {
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
@ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse(
return nil, err
}
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
}
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
// If isNoise is set, then the JSON body will be returned
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
func MarshalResponse(
resp interface{},
isNoise bool,
privateKey2019 *key.MachinePrivate,
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if !isNoise && privateKey2019 != nil {
return privateKey2019.SealTo(machineKey, jsonBody), nil
}
return jsonBody, nil
}
func (m Mapper) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
@ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse(
return data, nil
}
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
// If isNoise is set, then the JSON body will be returned
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
func MarshalResponse(
resp interface{},
isNoise bool,
privateKey2019 *key.MachinePrivate,
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if !isNoise && privateKey2019 != nil {
return privateKey2019.SealTo(machineKey, jsonBody), nil
}
return jsonBody, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
@ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{
return encoder
},
}
func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
}
online, err := m.db.ListOnlineMachines(machine)
if err == nil {
resp.OnlineChange = online
}
return resp
}

View file

@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
@ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},

View file

@ -3,24 +3,25 @@ package notifier
import (
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
type Notifier struct {
l sync.RWMutex
nodes map[string]chan<- struct{}
nodes map[string]chan<- types.StateUpdate
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
n.nodes = make(map[string]chan<- struct{})
n.nodes = make(map[string]chan<- types.StateUpdate)
}
n.nodes[machineKey] = c
@ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) {
delete(n.nodes, machineKey)
}
func (n *Notifier) NotifyAll() {
n.NotifyWithIgnore()
func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update)
}
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
n.l.RLock()
defer n.l.RUnlock()
@ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) {
continue
}
c <- struct{}{}
c <- update
}
}

View file

@ -116,7 +116,7 @@ func (h *Headscale) handlePoll(
return
}
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil {
logErr(err, "Failed to create MapResponse")
http.Error(writer, "", http.StatusInternalServerError)
@ -163,7 +163,12 @@ func (h *Headscale) handlePoll(
Inc()
// Tell all the other nodes about the new endpoint, but dont update ourselves.
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{
Type: types.StatePeerChanged,
Changed: []uint64{machine.ID},
},
machine.MachineKey)
return
} else if mapRequest.OmitPeers && mapRequest.Stream {
@ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream(
keepAliveTicker := time.NewTicker(keepAliveInterval)
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
updateChan := make(chan types.StateUpdate, chanSize)
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
@ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream(
for {
select {
case <-keepAliveTicker.C:
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
data, err := mapp.KeepAliveResponse(mapRequest, machine)
if err != nil {
logErr(err, "Error generating the keep alive msg")
@ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream(
return
}
case <-updateChan:
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
case update := <-updateChan:
var data []byte
var err error
switch update.Type {
case types.StateFullUpdate:
data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
case types.StatePeerChanged:
data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy)
case types.StatePeerRemoved:
data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed)
case types.StateDERPUpdated:
data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap)
}
if err != nil {
logErr(err, "Could not get the map update")
logErr(err, "Could not get the create map update")
return
}
@ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream(
}
}
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) {
log.Trace().
Str("handler", "PollNetMap").
Str("machine", machine).

View file

@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) {
return string(bytes), err
}
type StateUpdateType int
const (
StateFullUpdate StateUpdateType = iota
StatePeerChanged
StatePeerRemoved
StateDERPUpdated
)
// StateUpdate is an internal message containing information about
// a state change that has happened to the network.
type StateUpdate struct {
// The type of update
Type StateUpdateType
// Changed must be set when Type is StatePeerChanged and
// contain the Machine IDs of machines that has changed.
Changed []uint64
// Removed must be set when Type is StatePeerRemoved and
// contain a list of the nodes that has been removed from
// the network.
Removed []tailcfg.NodeID
// DERPMap must be set when Type is StateDERPUpdated and
// contain the new DERP Map.
DERPMap tailcfg.DERPMap
}