Fix utils tests, use ipset datastructure

This commit is contained in:
Kristoffer Dalby 2022-02-25 08:28:22 +00:00
parent eda0a9f88a
commit ebe59a5a27
2 changed files with 36 additions and 11 deletions

View file

@ -12,7 +12,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"sort"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -190,7 +189,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
} }
} }
func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
// FIXME: This really deserves a better data model, // FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet. // to begin experimenting with a dual stack tailnet.
@ -206,7 +205,7 @@ func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
var machineAddresses MachineAddresses var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice) err := machineAddresses.Scan(slice)
if err != nil { if err != nil {
return netaddr.IPSet{}, fmt.Errorf( return &netaddr.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w", "failed to read ip from database: %w",
err, err,
) )
@ -221,7 +220,15 @@ func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
Interface("addresses", ips). Interface("addresses", ips).
Msg("Parsed ip addresses that has been allocated from databases") Msg("Parsed ip addresses that has been allocated from databases")
return netaddr.IPSet{}, nil ipSet, err := ips.IPSet()
if err != nil {
return &netaddr.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
}
return ipSet, nil
} }
func containsString(ss []string, s string) bool { func containsString(ss []string, s string) bool {

View file

@ -48,9 +48,12 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netaddr.MustParseIP("10.27.0.1")
expectedIPSetBuilder := netaddr.IPSetBuilder{}
expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
c.Assert(len(usedIps), check.Equals, 1) c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
c.Assert(usedIps[0], check.Equals, expected) c.Assert(usedIps.Contains(expected), check.Equals, true)
machine1, err := app.GetMachineByID(0) machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -64,6 +67,8 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
for index := 1; index <= 350; index++ { for index := 1; index <= 350; index++ {
app.ipAllocationMutex.Lock()
ips, err := app.getAvailableIPs() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -86,17 +91,30 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
IPAddresses: ips, IPAddresses: ips,
} }
app.db.Save(&machine) app.db.Save(&machine)
app.ipAllocationMutex.Unlock()
} }
usedIps, err := app.getUsedIPs() usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(usedIps), check.Equals, 350) expected0 := netaddr.MustParseIP("10.27.0.1")
expected9 := netaddr.MustParseIP("10.27.0.10")
expected300 := netaddr.MustParseIP("10.27.0.45")
c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) notExpectedIPSetBuilder := netaddr.IPSetBuilder{}
c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) notExpectedIPSetBuilder.Add(expected0)
c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300)
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
c.Assert(err, check.IsNil)
// We actually expect it to be a lot larger
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
c.Assert(usedIps.Contains(expected0), check.Equals, true)
c.Assert(usedIps.Contains(expected9), check.Equals, true)
c.Assert(usedIps.Contains(expected300), check.Equals, true)
// Check that we can read back the IPs // Check that we can read back the IPs
machine1, err := app.GetMachineByID(1) machine1, err := app.GetMachineByID(1)