mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
optimize generateACLPeerCacheMap (#1377)
This commit is contained in:
parent
6215eb6471
commit
d0113732fe
3 changed files with 69 additions and 47 deletions
15
acls.go
15
acls.go
|
@ -163,23 +163,20 @@ func (h *Headscale) UpdateACLRules() error {
|
|||
// generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map
|
||||
// of which Sources ("*" and IPs) can access destinations. This is to speed up the
|
||||
// process of generating MapResponses when deciding which Peers to inform nodes about.
|
||||
func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]struct{} {
|
||||
aclCachePeerMap := make(map[string]map[string]struct{})
|
||||
func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string {
|
||||
aclCachePeerMap := make(map[string][]string)
|
||||
for _, rule := range rules {
|
||||
for _, srcIP := range rule.SrcIPs {
|
||||
for _, ip := range expandACLPeerAddr(srcIP) {
|
||||
if data, ok := aclCachePeerMap[ip]; ok {
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||
data[dstIP] = struct{}{}
|
||||
}
|
||||
data = append(data, dstPort.IP)
|
||||
}
|
||||
aclCachePeerMap[ip] = data
|
||||
} else {
|
||||
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
|
||||
dstPortsMap := make([]string, 0)
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||
dstPortsMap[dstIP] = struct{}{}
|
||||
}
|
||||
dstPortsMap = append(dstPortsMap, dstPort.IP)
|
||||
}
|
||||
aclCachePeerMap[ip] = dstPortsMap
|
||||
}
|
||||
|
|
2
app.go
2
app.go
|
@ -87,7 +87,7 @@ type Headscale struct {
|
|||
aclPolicy *ACLPolicy
|
||||
aclRules []tailcfg.FilterRule
|
||||
aclPeerCacheMapRW sync.RWMutex
|
||||
aclPeerCacheMap map[string]map[string]struct{}
|
||||
aclPeerCacheMap map[string][]string
|
||||
sshPolicy *tailcfg.SSHPolicy
|
||||
|
||||
lastStateChange *xsync.MapOf[string, time.Time]
|
||||
|
|
41
machine.go
41
machine.go
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -172,7 +173,7 @@ func filterMachinesByACL(
|
|||
machine *Machine,
|
||||
machines Machines,
|
||||
lock *sync.RWMutex,
|
||||
aclPeerCacheMap map[string]map[string]struct{},
|
||||
aclPeerCacheMap map[string][]string,
|
||||
) Machines {
|
||||
log.Trace().
|
||||
Caller().
|
||||
|
@ -197,43 +198,59 @@ func filterMachinesByACL(
|
|||
|
||||
if dstMap, ok := aclPeerCacheMap["*"]; ok {
|
||||
// match source and all destination
|
||||
if _, dstOk := dstMap["*"]; dstOk {
|
||||
|
||||
for _, dst := range dstMap {
|
||||
if dst == "*" {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// match source and all destination
|
||||
for _, peerIP := range peerIPs {
|
||||
if _, dstOk := dstMap[peerIP]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
_, cdr, _ := net.ParseCIDR(dst)
|
||||
ip := net.ParseIP(peerIP)
|
||||
if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// match all sources and source
|
||||
for _, machineIP := range machineIPs {
|
||||
if _, dstOk := dstMap[machineIP]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
_, cdr, _ := net.ParseCIDR(dst)
|
||||
ip := net.ParseIP(machineIP)
|
||||
if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, machineIP := range machineIPs {
|
||||
if dstMap, ok := aclPeerCacheMap[machineIP]; ok {
|
||||
// match source and all destination
|
||||
if _, dstOk := dstMap["*"]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
if dst == "*" {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// match source and destination
|
||||
for _, peerIP := range peerIPs {
|
||||
if _, dstOk := dstMap[peerIP]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
_, cdr, _ := net.ParseCIDR(dst)
|
||||
ip := net.ParseIP(peerIP)
|
||||
if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
|
@ -241,18 +258,25 @@ func filterMachinesByACL(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, peerIP := range peerIPs {
|
||||
if dstMap, ok := aclPeerCacheMap[peerIP]; ok {
|
||||
// match source and all destination
|
||||
if _, dstOk := dstMap["*"]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
if dst == "*" {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// match return path
|
||||
for _, machineIP := range machineIPs {
|
||||
if _, dstOk := dstMap[machineIP]; dstOk {
|
||||
for _, dst := range dstMap {
|
||||
_, cdr, _ := net.ParseCIDR(dst)
|
||||
ip := net.ParseIP(machineIP)
|
||||
if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
|
||||
peers[peer.ID] = peer
|
||||
|
||||
continue
|
||||
|
@ -261,6 +285,7 @@ func filterMachinesByACL(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lock.RUnlock()
|
||||
|
||||
|
|
Loading…
Reference in a new issue