mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
move MapResponse peer logic into function and reuse
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
387aa03adb
commit
432e975a7f
7 changed files with 193 additions and 173 deletions
|
@ -92,6 +92,8 @@ type Headscale struct {
|
|||
|
||||
shutdownChan chan struct{}
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
|
||||
pollStreamOpenMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
|
|
|
@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
|||
continue
|
||||
}
|
||||
|
||||
machine := &route.Machine
|
||||
|
||||
if !route.IsPrimary {
|
||||
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
|
||||
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
|||
return err
|
||||
}
|
||||
|
||||
changedMachines = append(changedMachines, &route.Machine)
|
||||
changedMachines = append(changedMachines, machine)
|
||||
|
||||
continue
|
||||
}
|
||||
|
@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
|||
return err
|
||||
}
|
||||
|
||||
changedMachines = append(changedMachines, &route.Machine)
|
||||
changedMachines = append(changedMachines, machine)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,6 +38,16 @@ const (
|
|||
|
||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||
|
||||
// TODO: Optimise
|
||||
// As this work continues, the idea is that there will be one Mapper instance
|
||||
// per node, attached to the open stream between the control and client.
|
||||
// This means that this can hold a state per machine and we can use that to
|
||||
// improve the mapresponses sent.
|
||||
// We could:
|
||||
// - Keep information about the previous mapresponse so we can send a diff
|
||||
// - Store hashes
|
||||
// - Create a "minifier" that removes info not needed for the node
|
||||
|
||||
type Mapper struct {
|
||||
privateKey2019 *key.MachinePrivate
|
||||
isNoise bool
|
||||
|
@ -102,105 +112,6 @@ func (m *Mapper) String() string {
|
|||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
||||
}
|
||||
|
||||
// TODO: Optimise
|
||||
// As this work continues, the idea is that there will be one Mapper instance
|
||||
// per node, attached to the open stream between the control and client.
|
||||
// This means that this can hold a state per machine and we can use that to
|
||||
// improve the mapresponses sent.
|
||||
// We could:
|
||||
// - Keep information about the previous mapresponse so we can send a diff
|
||||
// - Store hashes
|
||||
// - Create a "minifier" that removes info not needed for the node
|
||||
|
||||
// fullMapResponse is the internal function for generating a MapResponse
|
||||
// for a machine.
|
||||
func fullMapResponse(
|
||||
pol *policy.ACLPolicy,
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
|
||||
baseDomain string,
|
||||
dnsCfg *tailcfg.DNSConfig,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
logtail bool,
|
||||
randomClientPort bool,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
Node: tailnode,
|
||||
|
||||
DERPMap: derpMap,
|
||||
|
||||
Domain: baseDomain,
|
||||
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
CollectServices: "false",
|
||||
|
||||
ControlTime: &now,
|
||||
KeepAlive: false,
|
||||
OnlineChange: db.OnlineMachineMap(peers),
|
||||
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: !logtail,
|
||||
RandomizeClientPort: randomClientPort,
|
||||
},
|
||||
}
|
||||
|
||||
if peers != nil || len(peers) > 0 {
|
||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
||||
pol,
|
||||
machine,
|
||||
peers,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter out peers that have expired.
|
||||
peers = filterExpiredAndNotReady(peers)
|
||||
|
||||
// If there are filter rules present, see if there are any machines that cannot
|
||||
// access eachother at all and remove them from the peers.
|
||||
if len(rules) > 0 {
|
||||
peers = policy.FilterMachinesByACL(machine, peers, rules)
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(machine, peers, baseDomain)
|
||||
|
||||
dnsConfig := generateDNSConfig(
|
||||
dnsCfg,
|
||||
baseDomain,
|
||||
machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
resp.Peers = tailPeers
|
||||
resp.DNSConfig = dnsConfig
|
||||
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
||||
resp.UserProfiles = profiles
|
||||
resp.SSHPolicy = sshPolicy
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
|
@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) {
|
|||
}
|
||||
}
|
||||
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers := machineMapToList(m.peers)
|
||||
|
||||
resp, err := m.baseWithConfigMapResponse(machine, pol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Move this into appendPeerChanges?
|
||||
resp.OnlineChange = db.OnlineMachineMap(peers)
|
||||
|
||||
err = appendPeerChanges(
|
||||
resp,
|
||||
pol,
|
||||
machine,
|
||||
peers,
|
||||
peers,
|
||||
m.baseDomain,
|
||||
m.dnsCfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// FullMapResponse returns a MapResponse for the given machine.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
|
@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse(
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
mapResponse, err := fullMapResponse(
|
||||
pol,
|
||||
machine,
|
||||
machineMapToList(m.peers),
|
||||
m.baseDomain,
|
||||
m.dnsCfg,
|
||||
m.derpMap,
|
||||
m.logtail,
|
||||
m.randomClientPort,
|
||||
)
|
||||
resp, err := m.fullMapResponse(machine, pol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
// LiteMapResponse returns a MapResponse for the given machine.
|
||||
|
@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse(
|
|||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
) ([]byte, error) {
|
||||
mapResponse, err := fullMapResponse(
|
||||
pol,
|
||||
machine,
|
||||
nil,
|
||||
m.baseDomain,
|
||||
m.dnsCfg,
|
||||
m.derpMap,
|
||||
m.logtail,
|
||||
m.randomClientPort,
|
||||
)
|
||||
resp, err := m.baseWithConfigMapResponse(machine, pol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.isNoise {
|
||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
||||
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||
|
@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse(
|
|||
machine *types.Machine,
|
||||
derpMap tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp := m.baseMapResponse()
|
||||
resp.DERPMap = &derpMap
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||
|
@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse(
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var err error
|
||||
lastSeen := make(map[tailcfg.NodeID]bool)
|
||||
|
||||
// Update our internal map.
|
||||
|
@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse(
|
|||
lastSeen[tailcfg.NodeID(machine.ID)] = true
|
||||
}
|
||||
|
||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
err := appendPeerChanges(
|
||||
&resp,
|
||||
pol,
|
||||
machine,
|
||||
machineMapToList(m.peers),
|
||||
changed,
|
||||
m.baseDomain,
|
||||
m.dnsCfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
changed = filterExpiredAndNotReady(changed)
|
||||
|
||||
// If there are filter rules present, see if there are any machines that cannot
|
||||
// access eachother at all and remove them from the changed.
|
||||
if len(rules) > 0 {
|
||||
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp.PeersChanged = tailPeers
|
||||
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
||||
resp.SSHPolicy = sshPolicy
|
||||
// resp.PeerSeenChange = lastSeen
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||
|
@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse(
|
|||
delete(m.peers, uint64(id))
|
||||
}
|
||||
|
||||
resp := m.baseMapResponse(machine)
|
||||
resp := m.baseMapResponse()
|
||||
resp.PeersRemoved = removed
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||
|
@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse(
|
|||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
|
@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{
|
|||
},
|
||||
}
|
||||
|
||||
func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
|
||||
// baseMapResponse returns a tailcfg.MapResponse with
|
||||
// KeepAlive false and ControlTime set to now.
|
||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
|
@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
|
|||
ControlTime: &now,
|
||||
}
|
||||
|
||||
// online, err := m.db.ListOnlineMachines(machine)
|
||||
// if err == nil {
|
||||
// resp.OnlineChange = online
|
||||
// }
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
||||
// with the basic configuration from headscale set.
|
||||
// It is used in for bigger updates, such as full and lite, not
|
||||
// incremental.
|
||||
func (m *Mapper) baseWithConfigMapResponse(
|
||||
machine *types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
resp.DERPMap = m.derpMap
|
||||
|
||||
resp.Domain = m.baseDomain
|
||||
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
resp.CollectServices = "false"
|
||||
|
||||
resp.KeepAlive = false
|
||||
|
||||
resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !m.logtail,
|
||||
RandomizeClientPort: m.randomClientPort,
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
|
||||
ret := make(types.Machines, 0)
|
||||
|
||||
|
@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines {
|
|||
return !item.IsExpired() || len(item.Endpoints) > 0
|
||||
})
|
||||
}
|
||||
|
||||
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
||||
// necessary changes when peers have changed.
|
||||
func appendPeerChanges(
|
||||
resp *tailcfg.MapResponse,
|
||||
|
||||
pol *policy.ACLPolicy,
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
changed types.Machines,
|
||||
baseDomain string,
|
||||
dnsCfg *tailcfg.DNSConfig,
|
||||
) error {
|
||||
fullChange := len(peers) == len(changed)
|
||||
|
||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
||||
pol,
|
||||
machine,
|
||||
peers,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter out peers that have expired.
|
||||
changed = filterExpiredAndNotReady(changed)
|
||||
|
||||
// If there are filter rules present, see if there are any machines that cannot
|
||||
// access eachother at all and remove them from the peers.
|
||||
if len(rules) > 0 {
|
||||
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(machine, changed, baseDomain)
|
||||
|
||||
dnsConfig := generateDNSConfig(
|
||||
dnsCfg,
|
||||
baseDomain,
|
||||
machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
if fullChange {
|
||||
resp.Peers = tailPeers
|
||||
} else {
|
||||
resp.PeersChanged = tailPeers
|
||||
}
|
||||
resp.DNSConfig = dnsConfig
|
||||
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
||||
resp.UserProfiles = profiles
|
||||
resp.SSHPolicy = sshPolicy
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -441,9 +441,11 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{LoginName: "mini", DisplayName: "mini"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
|
@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := fullMapResponse(
|
||||
tt.pol,
|
||||
mappy := NewMapper(
|
||||
tt.machine,
|
||||
tt.peers,
|
||||
nil,
|
||||
false,
|
||||
tt.derpMap,
|
||||
tt.baseDomain,
|
||||
tt.dnsConfig,
|
||||
tt.derpMap,
|
||||
tt.logtail,
|
||||
tt.randomClientPort,
|
||||
)
|
||||
|
||||
got, err := mappy.fullMapResponse(
|
||||
tt.machine,
|
||||
tt.pol,
|
||||
)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
|
|
|
@ -55,6 +55,8 @@ func logPollFunc(
|
|||
|
||||
// handlePoll is the common code for the legacy and Noise protocols to
|
||||
// managed the poll loop.
|
||||
//
|
||||
//nolint:gocyclo
|
||||
func (h *Headscale) handlePoll(
|
||||
writer http.ResponseWriter,
|
||||
ctx context.Context,
|
||||
|
@ -67,6 +69,7 @@ func (h *Headscale) handlePoll(
|
|||
// following updates missing
|
||||
var updateChan chan types.StateUpdate
|
||||
if mapRequest.Stream {
|
||||
h.pollStreamOpenMu.Lock()
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
||||
|
@ -251,6 +254,8 @@ func (h *Headscale) handlePoll(
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
h.pollStreamOpenMu.Unlock()
|
||||
|
||||
for {
|
||||
logInfo("Waiting for update on stream channel")
|
||||
select {
|
||||
|
|
|
@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) {
|
|||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
// Omit 1.16.2 (-1) because it does not have the FQDN field
|
||||
"magicdns1": len(MustTestVersions) - 1,
|
||||
"magicdns2": len(MustTestVersions) - 1,
|
||||
"magicdns1": len(MustTestVersions),
|
||||
"magicdns2": len(MustTestVersions),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
|
||||
|
|
|
@ -20,10 +20,11 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
tsicHashLength = 6
|
||||
defaultPingCount = 10
|
||||
dockerContextPath = "../."
|
||||
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
||||
tsicHashLength = 6
|
||||
defaultPingTimeout = 300 * time.Millisecond
|
||||
defaultPingCount = 10
|
||||
dockerContextPath = "../."
|
||||
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption {
|
|||
// TODO(kradalby): Make multiping, go routine magic.
|
||||
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
|
||||
args := pingArgs{
|
||||
timeout: 300 * time.Millisecond,
|
||||
timeout: defaultPingTimeout,
|
||||
count: defaultPingCount,
|
||||
direct: true,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue