fix(acls,machines): apply code review suggestions

This commit is contained in:
Adrien Raffin-Caboisse 2022-02-20 21:24:02 +01:00
parent 4f9ece14c5
commit d00251c63e
3 changed files with 25 additions and 22 deletions

View file

@ -204,7 +204,7 @@ func expandAlias(
return ips, err return ips, err
} }
for _, n := range namespaces { for _, n := range namespaces {
nodes := listMachinesInNamespace(machines, n) nodes := filterMachinesByNamespace(machines, n)
for _, node := range nodes { for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
@ -219,7 +219,7 @@ func expandAlias(
return ips, err return ips, err
} }
for _, namespace := range owners { for _, namespace := range owners {
machines := listMachinesInNamespace(machines, namespace) machines := filterMachinesByNamespace(machines, namespace)
for _, machine := range machines { for _, machine := range machines {
if len(machine.HostInfo) == 0 { if len(machine.HostInfo) == 0 {
continue continue
@ -240,7 +240,7 @@ func expandAlias(
} }
// if alias is a namespace // if alias is a namespace
nodes := listMachinesInNamespace(machines, alias) nodes := filterMachinesByNamespace(machines, alias)
nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias) nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
if err != nil { if err != nil {
return ips, err return ips, err
@ -357,7 +357,7 @@ func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
return &ports, nil return &ports, nil
} }
func listMachinesInNamespace(machines []Machine, namespace string) []Machine { func filterMachinesByNamespace(machines []Machine, namespace string) []Machine {
out := []Machine{} out := []Machine{}
for _, machine := range machines { for _, machine := range machines {
if machine.Namespace.Name == namespace { if machine.Namespace.Name == namespace {

View file

@ -687,7 +687,7 @@ func Test_listMachinesInNamespace(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
if got := listMachinesInNamespace(test.args.machines, test.args.namespace); !reflect.DeepEqual( if got := filterMachinesByNamespace(test.args.machines, test.args.namespace); !reflect.DeepEqual(
got, got,
test.want, test.want,
) { ) {

View file

@ -142,6 +142,16 @@ func containsAddresses(inputs []string, addrs MachineAddresses) bool {
return false return false
} }
// matchSourceAndDestinationWithRule will check if source is authorized to communicate with destination through
// the given rule.
func matchSourceAndDestinationWithRule(rule tailcfg.FilterRule, source Machine, destination Machine) bool {
var dst []string
for _, d := range rule.DstPorts {
dst = append(dst, d.IP)
}
return (containsAddresses(rule.SrcIPs, source.IPAddresses) && containsAddresses(dst, destination.IPAddresses)) || containsString(dst, "*")
}
// getFilteredByACLPeerss should return the list of peers authorized to be accessed from machine. // getFilteredByACLPeerss should return the list of peers authorized to be accessed from machine.
func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
@ -149,14 +159,12 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) {
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Finding peers filtered by ACLs") Msg("Finding peers filtered by ACLs")
machines := Machines{} machines, err := h.ListAllMachines()
if err := h.db.Preload("Namespace").Where("machine_key <> ? AND registered", if err != nil {
machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error retrieving list of machines")
log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err return Machines{}, err
} }
mMachines := make(map[uint64]Machine) peers := make(map[uint64]Machine)
// Aclfilter peers here. We are itering through machines in all namespaces and search through the computed aclRules // Aclfilter peers here. We are itering through machines in all namespaces and search through the computed aclRules
// for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable.
@ -175,21 +183,16 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) {
// In order to do this we would need to be able to identify that node A want to talk to node B but that Node B doesn't know // In order to do this we would need to be able to identify that node A want to talk to node B but that Node B doesn't know
// how to talk to node A and then add the peering resource. // how to talk to node A and then add the peering resource.
for _, mchn := range machines { for _, peer := range machines {
for _, rule := range h.aclRules { for _, rule := range h.aclRules {
var dst []string if matchSourceAndDestinationWithRule(rule, *machine, peer) || matchSourceAndDestinationWithRule(rule, peer, *machine) {
for _, d := range rule.DstPorts { peers[peer.ID] = peer
dst = append(dst, d.IP)
}
if (containsAddresses(rule.SrcIPs, machine.IPAddresses) && (containsAddresses(dst, mchn.IPAddresses) || containsString(dst, "*"))) ||
(containsAddresses(rule.SrcIPs, mchn.IPAddresses) && containsAddresses(dst, machine.IPAddresses)) {
mMachines[mchn.ID] = mchn
} }
} }
} }
authorizedMachines := make([]Machine, 0, len(mMachines)) authorizedMachines := make([]Machine, 0, len(peers))
for _, m := range mMachines { for _, m := range peers {
authorizedMachines = append(authorizedMachines, m) authorizedMachines = append(authorizedMachines, m)
} }
sort.Slice( sort.Slice(
@ -200,7 +203,7 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msgf("Found some machines: %s", machines.String()) Msgf("Found some machines: %v", machines)
return authorizedMachines, nil return authorizedMachines, nil
} }