Migrate utils to net/netip

This commit is contained in:
Juan Font Alonso 2022-09-02 00:05:18 +02:00
parent a24710a961
commit cdf48b1216
2 changed files with 31 additions and 27 deletions

View file

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net" "net"
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -21,7 +22,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/viper" "github.com/spf13/viper"
"inet.af/netaddr" "go4.org/netipx"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -140,7 +141,7 @@ func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
var err error var err error
ipPrefixes := h.cfg.IPPrefixes ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes { for _, ipPrefix := range ipPrefixes {
var ip *netaddr.IP var ip *netip.Addr
ip, err = h.getAvailableIP(ipPrefix) ip, err = h.getAvailableIP(ipPrefix)
if err != nil { if err != nil {
return ips, err return ips, err
@ -151,16 +152,16 @@ func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
return ips, err return ips, err
} }
func GetIPPrefixEndpoints(na netaddr.IPPrefix) (netaddr.IP, netaddr.IP) { func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
var network, broadcast netaddr.IP var network, broadcast netip.Addr
ipRange := na.Range() ipRange := netipx.RangeOfPrefix(na)
network = ipRange.From() network = ipRange.From()
broadcast = ipRange.To() broadcast = ipRange.To()
return network, broadcast return network, broadcast
} }
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) { func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := h.getUsedIPs() usedIps, err := h.getUsedIPs()
if err != nil { if err != nil {
return nil, err return nil, err
@ -181,7 +182,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
fallthrough fallthrough
case usedIps.Contains(ip): case usedIps.Contains(ip):
fallthrough fallthrough
case ip.IsZero() || ip.IsLoopback(): case ip == netip.Addr{} || ip.IsLoopback():
ip = ip.Next() ip = ip.Next()
continue continue
@ -192,19 +193,19 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
} }
} }
func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) { func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model, // FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet. // to begin experimenting with a dual stack tailnet.
var addressesSlices []string var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
var ips netaddr.IPSetBuilder var ips netipx.IPSetBuilder
for _, slice := range addressesSlices { for _, slice := range addressesSlices {
var machineAddresses MachineAddresses var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice) err := machineAddresses.Scan(slice)
if err != nil { if err != nil {
return &netaddr.IPSet{}, fmt.Errorf( return &netipx.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w", "failed to read ip from database: %w",
err, err,
) )
@ -217,7 +218,7 @@ func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
ipSet, err := ips.IPSet() ipSet, err := ips.IPSet()
if err != nil { if err != nil {
return &netaddr.IPSet{}, fmt.Errorf( return &netipx.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w", "failed to build IP Set: %w",
err, err,
) )
@ -250,7 +251,7 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
return d.DialContext(ctx, "unix", addr) return d.DialContext(ctx, "unix", addr)
} }
func ipPrefixToString(prefixes []netaddr.IPPrefix) []string { func ipPrefixToString(prefixes []netip.Prefix) []string {
result := make([]string, len(prefixes)) result := make([]string, len(prefixes))
for index, prefix := range prefixes { for index, prefix := range prefixes {
@ -260,13 +261,13 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
return result return result
} }
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes)) result := make([]netip.Prefix, len(prefixes))
for index, prefixStr := range prefixes { for index, prefixStr := range prefixes {
prefix, err := netaddr.ParseIPPrefix(prefixStr) prefix, err := netip.ParsePrefix(prefixStr)
if err != nil { if err != nil {
return []netaddr.IPPrefix{}, err return []netip.Prefix{}, err
} }
result[index] = prefix result[index] = prefix
@ -275,7 +276,7 @@ func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil return result, nil
} }
func contains[T string | netaddr.IPPrefix](ts []T, t T) bool { func contains[T string | netip.Prefix](ts []T, t T) bool {
for _, v := range ts { for _, v := range ts {
if reflect.DeepEqual(v, t) { if reflect.DeepEqual(v, t) {
return true return true

View file

@ -1,6 +1,9 @@
package headscale package headscale
import ( import (
"net/netip"
"go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
) )
@ -10,7 +13,7 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String()) c.Assert(ips[0].String(), check.Equals, expected.String())
@ -46,8 +49,8 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netip.MustParseAddr("10.27.0.1")
expectedIPSetBuilder := netaddr.IPSetBuilder{} expectedIPSetBuilder := netipx.IPSetBuilder{}
expectedIPSetBuilder.Add(expected) expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet() expectedIPSet, _ := expectedIPSetBuilder.IPSet()
@ -96,11 +99,11 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
usedIps, err := app.getUsedIPs() usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected0 := netaddr.MustParseIP("10.27.0.1") expected0 := netip.MustParseAddr("10.27.0.1")
expected9 := netaddr.MustParseIP("10.27.0.10") expected9 := netip.MustParseAddr("10.27.0.10")
expected300 := netaddr.MustParseIP("10.27.0.45") expected300 := netip.MustParseAddr("10.27.0.45")
notExpectedIPSetBuilder := netaddr.IPSetBuilder{} notExpectedIPSetBuilder := netipx.IPSetBuilder{}
notExpectedIPSetBuilder.Add(expected0) notExpectedIPSetBuilder.Add(expected0)
notExpectedIPSetBuilder.Add(expected9) notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300) notExpectedIPSetBuilder.Add(expected300)
@ -121,7 +124,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert( c.Assert(
machine1.IPAddresses[0], machine1.IPAddresses[0],
check.Equals, check.Equals,
netaddr.MustParseIP("10.27.0.1"), netip.MustParseAddr("10.27.0.1"),
) )
machine50, err := app.GetMachineByID(50) machine50, err := app.GetMachineByID(50)
@ -130,10 +133,10 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert( c.Assert(
machine50.IPAddresses[0], machine50.IPAddresses[0],
check.Equals, check.Equals,
netaddr.MustParseIP("10.27.0.50"), netip.MustParseAddr("10.27.0.50"),
) )
expectedNextIP := netaddr.MustParseIP("10.27.1.95") expectedNextIP := netip.MustParseAddr("10.27.1.95")
nextIP, err := app.getAvailableIPs() nextIP, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)