diff --git a/machine.go b/machine.go index 2fa4d9e0..f97169b2 100644 --- a/machine.go +++ b/machine.go @@ -78,13 +78,13 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { return machines, nil } +// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for func (h *Headscale) getShared(m *Machine) (Machines, error) { log.Trace(). Str("func", "getShared"). Str("machine", m.Name). Msg("Finding shared peers") - // We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", m.NamespaceID).Find(&sharedMachines).Error; err != nil { @@ -105,6 +105,37 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { return peers, nil } +// getSharedTo fetches the machines of the namespaces this machine is shared in +func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { + log.Trace(). + Str("func", "getSharedTo"). + Str("machine", m.Name). + Msg("Finding peers in namespaces this machine is shared with") + + sharedMachines := []SharedMachine{} + if err := h.db.Preload("Namespace").Preload("Machine").Where("machine_id = ?", + m.ID).Find(&sharedMachines).Error; err != nil { + return Machines{}, err + } + + peers := make(Machines, 0) + for _, sharedMachine := range sharedMachines { + namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name) + if err != nil { + return Machines{}, err + } + peers = append(peers, *namespaceMachines...) + } + + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) + + log.Trace(). + Str("func", "getSharedTo"). + Str("machine", m.Name). + Msgf("Found peers we are shared with: %s", peers.String()) + return peers, nil +} + func (h *Headscale) getPeers(m *Machine) (Machines, error) { direct, err := h.getDirectPeers(m) if err != nil { @@ -118,13 +149,24 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { shared, err := h.getShared(m) if err != nil { log.Error(). - Str("func", "getDirectPeers"). + Str("func", "getShared"). + Err(err). + Msg("Cannot fetch peers") + return Machines{}, err + } + + sharedTo, err := h.getSharedTo(m) + if err != nil { + log.Error(). + Str("func", "sharedTo"). Err(err). Msg("Cannot fetch peers") return Machines{}, err } peers := append(direct, shared...) + peers = append(peers, sharedTo...) + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) log.Trace().