move MapResponse peer logic into function and reuse

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-08-09 22:56:21 +02:00 committed by Kristoffer Dalby
parent 387aa03adb
commit 432e975a7f
7 changed files with 193 additions and 173 deletions

View file

@ -92,6 +92,8 @@ type Headscale struct {
shutdownChan chan struct{} shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup pollNetMapStreamWG sync.WaitGroup
pollStreamOpenMu sync.Mutex
} }
func NewHeadscale(cfg *types.Config) (*Headscale, error) { func NewHeadscale(cfg *types.Config) (*Headscale, error) {

View file

@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
continue continue
} }
machine := &route.Machine
if !route.IsPrimary { if !route.IsPrimary {
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
return err return err
} }
changedMachines = append(changedMachines, &route.Machine) changedMachines = append(changedMachines, machine)
continue continue
} }
@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
return err return err
} }
changedMachines = append(changedMachines, &route.Machine) changedMachines = append(changedMachines, machine)
} }
} }

View file

@ -38,6 +38,16 @@ const (
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH") var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
// TODO: Optimise
// As this work continues, the idea is that there will be one Mapper instance
// per node, attached to the open stream between the control and client.
// This means that this can hold a state per machine and we can use that to
// improve the mapresponses sent.
// We could:
// - Keep information about the previous mapresponse so we can send a diff
// - Store hashes
// - Create a "minifier" that removes info not needed for the node
type Mapper struct { type Mapper struct {
privateKey2019 *key.MachinePrivate privateKey2019 *key.MachinePrivate
isNoise bool isNoise bool
@ -102,105 +112,6 @@ func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created) return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
} }
// TODO: Optimise
// As this work continues, the idea is that there will be one Mapper instance
// per node, attached to the open stream between the control and client.
// This means that this can hold a state per machine and we can use that to
// improve the mapresponses sent.
// We could:
// - Keep information about the previous mapresponse so we can send a diff
// - Store hashes
// - Create a "minifier" that removes info not needed for the node
// fullMapResponse is the internal function for generating a MapResponse
// for a machine.
func fullMapResponse(
pol *policy.ACLPolicy,
machine *types.Machine,
peers types.Machines,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
derpMap *tailcfg.DERPMap,
logtail bool,
randomClientPort bool,
) (*tailcfg.MapResponse, error) {
tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain)
if err != nil {
return nil, err
}
now := time.Now()
resp := tailcfg.MapResponse{
Node: tailnode,
DERPMap: derpMap,
Domain: baseDomain,
// Do not instruct clients to collect services we do not
// support or do anything with them
CollectServices: "false",
ControlTime: &now,
KeepAlive: false,
OnlineChange: db.OnlineMachineMap(peers),
Debug: &tailcfg.Debug{
DisableLogTail: !logtail,
RandomizeClientPort: randomClientPort,
},
}
if peers != nil || len(peers) > 0 {
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
machine,
peers,
)
if err != nil {
return nil, err
}
// Filter out peers that have expired.
peers = filterExpiredAndNotReady(peers)
// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the peers.
if len(rules) > 0 {
peers = policy.FilterMachinesByACL(machine, peers, rules)
}
profiles := generateUserProfiles(machine, peers, baseDomain)
dnsConfig := generateDNSConfig(
dnsCfg,
baseDomain,
machine,
peers,
)
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
resp.Peers = tailPeers
resp.DNSConfig = dnsConfig
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
}
return &resp, nil
}
func generateUserProfiles( func generateUserProfiles(
machine *types.Machine, machine *types.Machine,
peers types.Machines, peers types.Machines,
@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) {
} }
} }
// fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
machine *types.Machine,
pol *policy.ACLPolicy,
) (*tailcfg.MapResponse, error) {
peers := machineMapToList(m.peers)
resp, err := m.baseWithConfigMapResponse(machine, pol)
if err != nil {
return nil, err
}
// TODO(kradalby): Move this into appendPeerChanges?
resp.OnlineChange = db.OnlineMachineMap(peers)
err = appendPeerChanges(
resp,
pol,
machine,
peers,
peers,
m.baseDomain,
m.dnsCfg,
)
if err != nil {
return nil, err
}
return resp, nil
}
// FullMapResponse returns a MapResponse for the given machine. // FullMapResponse returns a MapResponse for the given machine.
func (m *Mapper) FullMapResponse( func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse(
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
mapResponse, err := fullMapResponse( resp, err := m.fullMapResponse(machine, pol)
pol,
machine,
machineMapToList(m.peers),
m.baseDomain,
m.dnsCfg,
m.derpMap,
m.logtail,
m.randomClientPort,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if m.isNoise { if m.isNoise {
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
} }
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
} }
// LiteMapResponse returns a MapResponse for the given machine. // LiteMapResponse returns a MapResponse for the given machine.
@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse(
machine *types.Machine, machine *types.Machine,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
) ([]byte, error) { ) ([]byte, error) {
mapResponse, err := fullMapResponse( resp, err := m.baseWithConfigMapResponse(machine, pol)
pol,
machine,
nil,
m.baseDomain,
m.dnsCfg,
m.derpMap,
m.logtail,
m.randomClientPort,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if m.isNoise { if m.isNoise {
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
} }
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
} }
func (m *Mapper) KeepAliveResponse( func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *types.Machine, machine *types.Machine,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse(machine) resp := m.baseMapResponse()
resp.KeepAlive = true resp.KeepAlive = true
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse(
machine *types.Machine, machine *types.Machine,
derpMap tailcfg.DERPMap, derpMap tailcfg.DERPMap,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse(machine) resp := m.baseMapResponse()
resp.DERPMap = &derpMap resp.DERPMap = &derpMap
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse(
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
var err error
lastSeen := make(map[tailcfg.NodeID]bool) lastSeen := make(map[tailcfg.NodeID]bool)
// Update our internal map. // Update our internal map.
@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse(
lastSeen[tailcfg.NodeID(machine.ID)] = true lastSeen[tailcfg.NodeID(machine.ID)] = true
} }
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( resp := m.baseMapResponse()
err := appendPeerChanges(
&resp,
pol, pol,
machine, machine,
machineMapToList(m.peers), machineMapToList(m.peers),
changed,
m.baseDomain,
m.dnsCfg,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
changed = filterExpiredAndNotReady(changed)
// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the changed.
if len(rules) > 0 {
changed = policy.FilterMachinesByACL(machine, changed, rules)
}
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
resp := m.baseMapResponse(machine)
resp.PeersChanged = tailPeers
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
resp.SSHPolicy = sshPolicy
// resp.PeerSeenChange = lastSeen // resp.PeerSeenChange = lastSeen
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse(
delete(m.peers, uint64(id)) delete(m.peers, uint64(id))
} }
resp := m.baseMapResponse(machine) resp := m.baseMapResponse()
resp.PeersRemoved = removed resp.PeersRemoved = removed
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse(
panic(err) panic(err)
} }
now := time.Now().Unix() now := time.Now().UnixNano()
mapResponsePath := path.Join( mapResponsePath := path.Join(
mPath, mPath,
@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{
}, },
} }
func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse { // baseMapResponse returns a tailcfg.MapResponse with
// KeepAlive false and ControlTime set to now.
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
now := time.Now() now := time.Now()
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
ControlTime: &now, ControlTime: &now,
} }
// online, err := m.db.ListOnlineMachines(machine)
// if err == nil {
// resp.OnlineChange = online
// }
return resp return resp
} }
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
// with the basic configuration from headscale set.
// It is used in for bigger updates, such as full and lite, not
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
machine *types.Machine,
pol *policy.ACLPolicy,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain)
if err != nil {
return nil, err
}
resp.Node = tailnode
resp.DERPMap = m.derpMap
resp.Domain = m.baseDomain
// Do not instruct clients to collect services we do not
// support or do anything with them
resp.CollectServices = "false"
resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.logtail,
RandomizeClientPort: m.randomClientPort,
}
return &resp, nil
}
func machineMapToList(machines map[uint64]*types.Machine) types.Machines { func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
ret := make(types.Machines, 0) ret := make(types.Machines, 0)
@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines {
return !item.IsExpired() || len(item.Endpoints) > 0 return !item.IsExpired() || len(item.Endpoints) > 0
}) })
} }
// appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed.
func appendPeerChanges(
resp *tailcfg.MapResponse,
pol *policy.ACLPolicy,
machine *types.Machine,
peers types.Machines,
changed types.Machines,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
) error {
fullChange := len(peers) == len(changed)
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
machine,
peers,
)
if err != nil {
return err
}
// Filter out peers that have expired.
changed = filterExpiredAndNotReady(changed)
// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the peers.
if len(rules) > 0 {
changed = policy.FilterMachinesByACL(machine, changed, rules)
}
profiles := generateUserProfiles(machine, changed, baseDomain)
dnsConfig := generateDNSConfig(
dnsCfg,
baseDomain,
machine,
peers,
)
tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain)
if err != nil {
return err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
if fullChange {
resp.Peers = tailPeers
} else {
resp.PeersChanged = tailPeers
}
resp.DNSConfig = dnsConfig
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
return nil
}

View file

@ -441,9 +441,11 @@ func Test_fullMapResponse(t *testing.T) {
}, },
}, },
}, },
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, {LoginName: "mini", DisplayName: "mini"},
ControlTime: &time.Time{}, },
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableLogTail: true, DisableLogTail: true,
}, },
@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := fullMapResponse( mappy := NewMapper(
tt.pol,
tt.machine, tt.machine,
tt.peers, tt.peers,
nil,
false,
tt.derpMap,
tt.baseDomain, tt.baseDomain,
tt.dnsConfig, tt.dnsConfig,
tt.derpMap,
tt.logtail, tt.logtail,
tt.randomClientPort, tt.randomClientPort,
) )
got, err := mappy.fullMapResponse(
tt.machine,
tt.pol,
)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)

View file

@ -55,6 +55,8 @@ func logPollFunc(
// handlePoll is the common code for the legacy and Noise protocols to // handlePoll is the common code for the legacy and Noise protocols to
// managed the poll loop. // managed the poll loop.
//
//nolint:gocyclo
func (h *Headscale) handlePoll( func (h *Headscale) handlePoll(
writer http.ResponseWriter, writer http.ResponseWriter,
ctx context.Context, ctx context.Context,
@ -67,6 +69,7 @@ func (h *Headscale) handlePoll(
// following updates missing // following updates missing
var updateChan chan types.StateUpdate var updateChan chan types.StateUpdate
if mapRequest.Stream { if mapRequest.Stream {
h.pollStreamOpenMu.Lock()
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -251,6 +254,8 @@ func (h *Headscale) handlePoll(
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
h.pollStreamOpenMu.Unlock()
for { for {
logInfo("Waiting for update on stream channel") logInfo("Waiting for update on stream channel")
select { select {

View file

@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) {
defer scenario.Shutdown() defer scenario.Shutdown()
spec := map[string]int{ spec := map[string]int{
// Omit 1.16.2 (-1) because it does not have the FQDN field "magicdns1": len(MustTestVersions),
"magicdns1": len(MustTestVersions) - 1, "magicdns2": len(MustTestVersions),
"magicdns2": len(MustTestVersions) - 1,
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))

View file

@ -20,10 +20,11 @@ import (
) )
const ( const (
tsicHashLength = 6 tsicHashLength = 6
defaultPingCount = 10 defaultPingTimeout = 300 * time.Millisecond
dockerContextPath = "../." defaultPingCount = 10
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" dockerContextPath = "../."
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
) )
var ( var (
@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption {
// TODO(kradalby): Make multiping, go routine magic. // TODO(kradalby): Make multiping, go routine magic.
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error { func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
args := pingArgs{ args := pingArgs{
timeout: 300 * time.Millisecond, timeout: defaultPingTimeout,
count: defaultPingCount, count: defaultPingCount,
direct: true, direct: true,
} }