From e55fe0671a9abc51cbd954cf138cf59fd23682d5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 26 Jul 2023 13:55:03 +0200 Subject: [PATCH] only send lite map responses when omitpeers Signed-off-by: Kristoffer Dalby --- hscontrol/mapper/mapper.go | 127 +++++++++++++++++++++++-------------- hscontrol/poll.go | 36 ++++++++--- 2 files changed, 104 insertions(+), 59 deletions(-) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 8ae71c90..0d4ded5d 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -124,66 +124,19 @@ func fullMapResponse( return nil, err } - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( - pol, - machine, - peers, - ) - if err != nil { - return nil, err - } - - // Filter out peers that have expired. - peers = lo.Filter(peers, func(item types.Machine, index int) bool { - return !item.IsExpired() - }) - - // If there are filter rules present, see if there are any machines that cannot - // access eachother at all and remove them from the peers. - if len(rules) > 0 { - peers = policy.FilterMachinesByACL(machine, peers, rules) - } - - profiles := generateUserProfiles(machine, peers, baseDomain) - - dnsConfig := generateDNSConfig( - dnsCfg, - baseDomain, - *machine, - peers, - ) - - tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) - if err != nil { - return nil, err - } - - // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID - }) - now := time.Now() resp := tailcfg.MapResponse{ - Node: tailnode, - Peers: tailPeers, + Node: tailnode, DERPMap: derpMap, - DNSConfig: dnsConfig, - Domain: baseDomain, + Domain: baseDomain, // Do not instruct clients to collect services we do not // support or do anything with them CollectServices: "false", - PacketFilter: policy.ReduceFilterRules(machine, rules), - - UserProfiles: profiles, - - SSHPolicy: sshPolicy, - ControlTime: &now, KeepAlive: false, OnlineChange: db.OnlineMachineMap(peers), @@ -194,6 +147,53 @@ func fullMapResponse( }, } + if peers != nil || len(peers) > 0 { + rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + pol, + machine, + peers, + ) + if err != nil { + return nil, err + } + + // Filter out peers that have expired. + peers = lo.Filter(peers, func(item types.Machine, index int) bool { + return !item.IsExpired() + }) + + // If there are filter rules present, see if there are any machines that cannot + // access eachother at all and remove them from the peers. + if len(rules) > 0 { + peers = policy.FilterMachinesByACL(machine, peers, rules) + } + + profiles := generateUserProfiles(machine, peers, baseDomain) + + dnsConfig := generateDNSConfig( + dnsCfg, + baseDomain, + *machine, + peers, + ) + + tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) + if err != nil { + return nil, err + } + + // Peers is always returned sorted by Node.ID. + sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + resp.Peers = tailPeers + resp.DNSConfig = dnsConfig + resp.PacketFilter = policy.ReduceFilterRules(machine, rules) + resp.UserProfiles = profiles + resp.SSHPolicy = sshPolicy + } + return &resp, nil } @@ -327,6 +327,35 @@ func (m *Mapper) FullMapResponse( return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) } +// LiteMapResponse returns a MapResponse for the given machine. +// Lite means that the peers has been omited, this is intended +// to be used to answer MapRequests with OmitPeers set to true. +func (m *Mapper) LiteMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + pol *policy.ACLPolicy, +) ([]byte, error) { + mapResponse, err := fullMapResponse( + pol, + machine, + nil, + m.baseDomain, + m.dnsCfg, + m.derpMap, + m.logtail, + m.randomClientPort, + ) + if err != nil { + return nil, err + } + + if m.isNoise { + return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) + } + + return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) +} + func (m *Mapper) KeepAliveResponse( mapRequest tailcfg.MapRequest, machine *types.Machine, diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 075e6825..b717e324 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -116,14 +116,6 @@ func (h *Headscale) handlePoll( return } - mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) - if err != nil { - logErr(err, "Failed to create MapResponse") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - // We update our peers if the client is not sending ReadOnly in the MapRequest // so we don't distribute its initial request (it comes with // empty endpoints to peers) @@ -134,9 +126,17 @@ func (h *Headscale) handlePoll( if mapRequest.ReadOnly { logInfo("Client is starting up. Probably interested in a DERP map") + mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(mapResp) + _, err = writer.Write(mapResp) if err != nil { logErr(err, "Failed to write response") } @@ -151,9 +151,17 @@ func (h *Headscale) handlePoll( if mapRequest.OmitPeers && !mapRequest.Stream { logInfo("Client sent endpoint update and is ok with a response without peer list") + mapResp, err := mapp.LiteMapResponse(mapRequest, machine, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(mapResp) + _, err = writer.Write(mapResp) if err != nil { logErr(err, "Failed to write response") } @@ -183,6 +191,14 @@ func (h *Headscale) handlePoll( logInfo("Sending initial map") + mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + // Send the client an update to make sure we send an initial mapresponse _, err = writer.Write(mapResp) if err != nil {