Add support for multiple IP prefixes

This commit is contained in:
Csaba Sarkadi 2022-01-16 14:16:59 +01:00
parent 3a3aecb774
commit 1a6e5d8770
18 changed files with 423 additions and 257 deletions

View file

@ -185,7 +185,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
return nil, errInvalidNamespace return nil, errInvalidNamespace
} }
for _, node := range nodes { for _, node := range nodes {
ips = append(ips, node.IPAddress) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
} }
@ -219,7 +219,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
// FIXME: Check TagOwners allows this // FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags { for _, t := range hostinfo.RequestTags {
if alias[4:] == t { if alias[4:] == t {
ips = append(ips, machine.IPAddress) ips = append(ips, machine.IPAddresses.ToStringSlice()...)
break break
} }
@ -238,7 +238,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
} }
ips := []string{} ips := []string{}
for _, n := range nodes { for _, n := range nodes {
ips = append(ips, n.IPAddress) ips = append(ips, n.IPAddresses.ToStringSlice()...)
} }
return ips, nil return ips, nil

View file

@ -61,9 +61,9 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1) c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(5400)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(5400))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500))
} }
func (s *Suite) TestPortWildcard(c *check.C) { func (s *Suite) TestPortWildcard(c *check.C) {
@ -75,11 +75,11 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1) c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Equals, "*") c.Assert(rules[0].SrcIPs[0], check.Equals, "*")
} }
func (s *Suite) TestPortNamespace(c *check.C) { func (s *Suite) TestPortNamespace(c *check.C) {
@ -91,7 +91,7 @@ func (s *Suite) TestPortNamespace(c *check.C) {
_, err = app.GetMachine("testnamespace", "testmachine") _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
ip, _ := app.getAvailableIP() ips, _ := app.getAvailableIPs()
machine := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
@ -101,7 +101,7 @@ func (s *Suite) TestPortNamespace(c *check.C) {
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: ip.String(), IPAddresses: ips,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.Save(&machine) app.db.Save(&machine)
@ -116,12 +116,13 @@ func (s *Suite) TestPortNamespace(c *check.C) {
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1) c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip") c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String()) c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
} }
func (s *Suite) TestPortGroup(c *check.C) { func (s *Suite) TestPortGroup(c *check.C) {
@ -133,7 +134,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
_, err = app.GetMachine("testnamespace", "testmachine") _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
ip, _ := app.getAvailableIP() ips, _ := app.getAvailableIPs()
machine := Machine{ machine := Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
@ -143,7 +144,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: ip.String(), IPAddresses: ips,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.Save(&machine) app.db.Save(&machine)
@ -156,10 +157,11 @@ func (s *Suite) TestPortGroup(c *check.C) {
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1) c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip") c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String()) c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
} }

15
api.go
View file

@ -497,6 +497,7 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -554,14 +555,14 @@ func (h *Headscale) handleAuthKey(
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address") Msg("Authentication key was valid, proceeding to acquire IP addresses")
ip, err := h.getAvailableIP() ips, err := h.getAvailableIPs()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Failed to find an available IP") Msg("Failed to find an available IP address")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
@ -570,12 +571,12 @@ func (h *Headscale) handleAuthKey(
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ips", strings.Join(ips.ToStringSlice(), ",")).
Msgf("Assigning %s to %s", ip, machine.Name) Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name)
machine.Expiry = &registerRequest.Expiry machine.Expiry = &registerRequest.Expiry
machine.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
machine.IPAddress = ip.String() machine.IPAddresses = ips
machine.NamespaceID = pak.NamespaceID machine.NamespaceID = pak.NamespaceID
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
@ -610,6 +611,6 @@ func (h *Headscale) handleAuthKey(
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Str("ip", machine.IPAddress). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

6
app.go
View file

@ -68,7 +68,7 @@ type Config struct {
ServerURL string ServerURL string
Addr string Addr string
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefix netaddr.IPPrefix IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string PrivateKeyPath string
BaseDomain string BaseDomain string
@ -197,9 +197,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
} }
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := generateMagicDNSRootDomains( magicDNSDomains := generateMagicDNSRootDomains(app.cfg.IPPrefixes)
app.cfg.IPPrefix,
)
// we might have routes already from Split DNS // we might have routes already from Split DNS
if app.cfg.DNSConfig.Routes == nil { if app.cfg.DNSConfig.Routes == nil {
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)

View file

@ -41,7 +41,9 @@ func (s *Suite) ResetDB(c *check.C) {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{ cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), IPPrefixes: []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("10.27.0.0/23"),
},
} }
app = Headscale{ app = Headscale{

View file

@ -4,6 +4,7 @@ import (
"time" "time"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr"
) )
func (s *Suite) TestRegisterMachine(c *check.C) { func (s *Suite) TestRegisterMachine(c *check.C) {
@ -19,16 +20,17 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
IPAddress: "10.0.0.1", IPAddresses: []netaddr.IP{netaddr.MustParseIP("10.0.0.1")},
Expiry: &now, Expiry: &now,
} }
app.db.Save(&machine) err = app.db.Save(&machine).Error
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "testmachine") _, err = app.GetMachine(namespace.Name, machine.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineAfterRegistering, err := app.RegisterMachine( machineAfterRegistering, err := app.RegisterMachine(
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", machine.MachineKey,
namespace.Name, namespace.Name,
) )
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"strconv" "strconv"
"strings"
"time" "time"
survey "github.com/AlecAivazis/survey/v2" survey "github.com/AlecAivazis/survey/v2"
@ -459,7 +460,7 @@ func nodesToPtables(
"Name", "Name",
"NodeKey", "NodeKey",
"Namespace", "Namespace",
"IP address", "IP addresses",
"Ephemeral", "Ephemeral",
"Last seen", "Last seen",
"Online", "Online",
@ -523,7 +524,7 @@ func nodesToPtables(
machine.Name, machine.Name,
nodeKey.ShortString(), nodeKey.ShortString(),
namespace, namespace,
machine.IpAddress, strings.Join(machine.IpAddresses, ", "),
strconv.FormatBool(ephemeral), strconv.FormatBool(ephemeral),
lastSeenTime, lastSeenTime,
online, online,

View file

@ -41,7 +41,7 @@ func LoadConfig(path string) error {
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
viper.SetDefault("ip_prefix", "100.64.0.0/10") viper.SetDefault("ip_prefixes", []string{"100.64.0.0/10"})
viper.SetDefault("log_level", "info") viper.SetDefault("log_level", "info")
@ -221,10 +221,20 @@ func getHeadscaleConfig() headscale.Config {
dnsConfig, baseDomain := GetDNSConfig() dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig() derpConfig := GetDERPConfig()
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
prefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes))
for i, prefixInConfig := range configuredPrefixes {
prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
}
prefixes = append(prefixes, prefix)
}
return headscale.Config{ return headscale.Config{
ServerURL: viper.GetString("server_url"), ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"), Addr: viper.GetString("listen_addr"),
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")), IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain, BaseDomain: baseDomain,

27
dns.go
View file

@ -34,14 +34,25 @@ const (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries. // This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains( func generateMagicDNSRootDomains(ipPrefixes []netaddr.IPPrefix) []dnsname.FQDN {
ipPrefix netaddr.IPPrefix, fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
) []dnsname.FQDN { for _, ipPrefix := range ipPrefixes {
// TODO(juanfont): we are not handing out IPv6 addresses yet var generateDnsRoot func(netaddr.IPPrefix) []dnsname.FQDN
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network) switch ipPrefix.IP().BitLen() {
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.") case 32:
fqdns := []dnsname.FQDN{ipv6base} generateDnsRoot = generateIPv4DNSRootDomain
default:
panic(fmt.Sprintf("unsupported IP version with address length %d", ipPrefix.IP().BitLen()))
}
fqdns = append(fqdns, generateDnsRoot(ipPrefix)...)
}
return fqdns
}
func generateIPv4DNSRootDomain(ipPrefix netaddr.IPPrefix) (fqdns []dnsname.FQDN) {
// Conversion to the std lib net.IPnet, a bit easier to operate // Conversion to the std lib net.IPnet, a bit easier to operate
netRange := ipPrefix.IPNet() netRange := ipPrefix.IPNet()
maskBits, _ := netRange.Mask.Size() maskBits, _ := netRange.Mask.Size()
@ -73,7 +84,7 @@ func generateMagicDNSRootDomains(
fqdns = append(fqdns, fqdn) fqdns = append(fqdns, fqdn)
} }
return fqdns return
} }
func getMapResponseDNSConfig( func getMapResponseDNSConfig(

View file

@ -124,7 +124,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.1", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
app.db.Save(machineInShared1) app.db.Save(machineInShared1)
@ -142,7 +142,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.2", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
app.db.Save(machineInShared2) app.db.Save(machineInShared2)
@ -160,7 +160,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.3", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
app.db.Save(machineInShared3) app.db.Save(machineInShared3)
@ -178,7 +178,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.4", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID), AuthKeyID: uint(PreAuthKey2InShared1.ID),
} }
app.db.Save(machine2InShared1) app.db.Save(machine2InShared1)
@ -273,7 +273,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.1", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
app.db.Save(machineInShared1) app.db.Save(machineInShared1)
@ -291,7 +291,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.2", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
app.db.Save(machineInShared2) app.db.Save(machineInShared2)
@ -309,7 +309,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.3", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
app.db.Save(machineInShared3) app.db.Save(machineInShared3)
@ -327,7 +327,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.4", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID), AuthKeyID: uint(preAuthKey2InShared1.ID),
} }
app.db.Save(machine2InShared1) app.db.Save(machine2InShared1)

View file

@ -372,70 +372,74 @@ func (s *IntegrationTestSuite) TestListNodes() {
func (s *IntegrationTestSuite) TestGetIpAddresses() { func (s *IntegrationTestSuite) TestGetIpAddresses() {
for _, scales := range s.namespaces { for _, scales := range s.namespaces {
ipPrefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
ips, err := getIPs(scales.tailscales) ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for hostname := range scales.tailscales { for hostname, _ := range scales.tailscales {
s.T().Run(hostname, func(t *testing.T) { ips := ips[hostname]
ip, ok := ips[hostname] for _, ip := range ips {
s.T().Run(hostname, func(t *testing.T) {
assert.NotNil(t, ip)
assert.True(t, ok) fmt.Printf("IP for %s: %s\n", hostname, ip)
assert.NotNil(t, ip)
fmt.Printf("IP for %s: %s\n", hostname, ip) // c.Assert(ip.Valid(), check.IsTrue)
assert.True(t, ip.Is4() || ip.Is6())
// c.Assert(ip.Valid(), check.IsTrue) switch {
assert.True(t, ip.Is4()) case ip.Is4():
assert.True(t, ipPrefix.Contains(ip)) assert.True(t, IpPrefix4.Contains(ip))
}) case ip.Is6():
assert.True(t, IpPrefix6.Contains(ip))
}
})
}
} }
} }
} }
// TODO(kradalby): fix this test // TODO(kradalby): fix this test
// We need some way to impot ipnstate.Status from multiple go packages. // We need some way to import ipnstate.Status from multiple go packages.
// Currently it will only work with 1.18.x since that is the last // Currently it will only work with 1.18.x since that is the last
// version we have in go.mod // version we have in go.mod
// func (s *IntegrationTestSuite) TestStatus() { // func (s *IntegrationTestSuite) TestStatus() {
// for _, scales := range s.namespaces { // for _, scales := range s.namespaces {
// ips, err := getIPs(scales.tailscales) // ips, err := getIPs(scales.tailscales)
// assert.Nil(s.T(), err) // assert.Nil(s.T(), err)
// //
// for hostname, tailscale := range scales.tailscales { // for hostname, tailscale := range scales.tailscales {
// s.T().Run(hostname, func(t *testing.T) { // s.T().Run(hostname, func(t *testing.T) {
// command := []string{"tailscale", "status", "--json"} // command := []string{"tailscale", "status", "--json"}
// //
// fmt.Printf("Getting status for %s\n", hostname) // fmt.Printf("Getting status for %s\n", hostname)
// result, err := ExecuteCommand( // result, err := ExecuteCommand(
// &tailscale, // &tailscale,
// command, // command,
// []string{}, // []string{},
// ) // )
// assert.Nil(t, err) // assert.Nil(t, err)
// //
// var status ipnstate.Status // var status ipnstate.Status
// err = json.Unmarshal([]byte(result), &status) // err = json.Unmarshal([]byte(result), &status)
// assert.Nil(s.T(), err) // assert.Nil(s.T(), err)
// //
// // TODO(kradalby): Replace this check with peer length of SAME namespace // // TODO(kradalby): Replace this check with peer length of SAME namespace
// // Check if we have as many nodes in status // // Check if we have as many nodes in status
// // as we have IPs/tailscales // // as we have IPs/tailscales
// // lines := strings.Split(result, "\n") // // lines := strings.Split(result, "\n")
// // assert.Equal(t, len(ips), len(lines)-1) // // assert.Equal(t, len(ips), len(lines)-1)
// // assert.Equal(t, len(scales.tailscales), len(lines)-1) // // assert.Equal(t, len(scales.tailscales), len(lines)-1)
// //
// peerIps := getIPsfromIPNstate(status) // peerIps := getIPsfromIPNstate(status)
// //
// // Check that all hosts is present in all hosts status // // Check that all hosts is present in all hosts status
// for ipHostname, ip := range ips { // for ipHostname, ip := range ips {
// if hostname != ipHostname { // if hostname != ipHostname {
// assert.Contains(t, peerIps, ip) // assert.Contains(t, peerIps, ip)
// } // }
// } // }
// }) // })
// } // }
// } // }
// } // }
func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP { func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP {
@ -448,16 +452,19 @@ func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP {
return ips return ips
} }
func (s *IntegrationTestSuite) TestPingAllPeers() { func (s *IntegrationTestSuite) TestPingAllPeersByAddress() {
for _, scales := range s.namespaces { for _, scales := range s.namespaces {
ips, err := getIPs(scales.tailscales) ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for hostname, tailscale := range scales.tailscales { for hostname, tailscale := range scales.tailscales {
for peername, ip := range ips { for peername, peerIPs := range ips {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { for i, ip := range peerIPs {
// We currently cant ping ourselves, so skip that. // We currently cant ping ourselves, so skip that.
if peername != hostname { if peername == hostname {
continue
}
s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) {
// We are only interested in "direct ping" which means what we // We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node. // might need a couple of more attempts before reaching the node.
command := []string{ command := []string{
@ -469,9 +476,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
} }
fmt.Printf( fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n", "Pinging from %s to %s (%s)\n",
hostname, hostname,
ips[hostname],
peername, peername,
ip, ip,
) )
@ -483,8 +489,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
assert.Nil(t, err) assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result) fmt.Printf("Result for %s: %s\n", hostname, result)
assert.Contains(t, result, "pong") assert.Contains(t, result, "pong")
} })
}) }
} }
} }
} }
@ -553,17 +559,17 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
// TODO(juanfont): We have to find out why do we need to wait // TODO(juanfont): We have to find out why do we need to wait
time.Sleep(100 * time.Second) // Wait for the nodes to receive updates time.Sleep(100 * time.Second) // Wait for the nodes to receive updates
mainIps, err := getIPs(main.tailscales)
assert.Nil(s.T(), err)
sharedIps, err := getIPs(shared.tailscales) sharedIps, err := getIPs(shared.tailscales)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for hostname, tailscale := range main.tailscales { for hostname, tailscale := range main.tailscales {
for peername, ip := range sharedIps { for peername, peerIPs := range sharedIps {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { for i, ip := range peerIPs {
// We currently cant ping ourselves, so skip that. // We currently cant ping ourselves, so skip that.
if peername != hostname { if peername == hostname {
continue
}
s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) {
// We are only interested in "direct ping" which means what we // We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node. // might need a couple of more attempts before reaching the node.
command := []string{ command := []string{
@ -575,9 +581,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
} }
fmt.Printf( fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n", "Pinging from %s to %s (%s)\n",
hostname, hostname,
mainIps[hostname],
peername, peername,
ip, ip,
) )
@ -589,8 +594,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
assert.Nil(t, err) assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result) fmt.Printf("Result for %s: %s\n", hostname, result)
assert.Contains(t, result, "pong") assert.Contains(t, result, "pong")
} })
}) }
} }
} }
} }
@ -607,7 +612,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
_, err := ExecuteCommand( _, err := ExecuteCommand(
&tailscale, &tailscale,
command, command,
[]string{}, []string{"GOMAXPROCS=32"},
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for peername, ip := range ips { for peername, ip := range ips {
@ -653,7 +658,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
_, err = ExecuteCommand( _, err = ExecuteCommand(
&tailscale, &tailscale,
command, command,
[]string{"ALL_PROXY=socks5://localhost:1055"}, []string{"ALL_PROXY=socks5://localhost:1055", "GOMAXPROCS=32"},
) )
if err == nil { if err == nil {
break break
@ -684,78 +689,125 @@ func (s *IntegrationTestSuite) TestTailDrop() {
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for peername, ip := range ips { for peername, ip := range ips {
if peername == hostname {
continue
}
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname { command := []string{
command := []string{ "ls",
"ls", fmt.Sprintf("/tmp/file_from_%s", peername),
fmt.Sprintf("/tmp/file_from_%s", peername),
}
fmt.Printf(
"Checking file in %s (%s) from %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result)
assert.Equal(
t,
result,
fmt.Sprintf("/tmp/file_from_%s\n", peername),
)
} }
fmt.Printf(
"Checking file in %s (%s) from %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result)
assert.Equal(
t,
fmt.Sprintf("/tmp/file_from_%s\n", peername),
result,
)
}) })
} }
} }
} }
} }
func (s *IntegrationTestSuite) TestMagicDNS() { func (s *IntegrationTestSuite) TestPingAllPeersByHostname() {
for namespace, scales := range s.namespaces { for namespace, scales := range s.namespaces {
ips, err := getIPs(scales.tailscales) ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
for hostname, tailscale := range scales.tailscales { for hostname, tailscale := range scales.tailscales {
for peername, ip := range ips { for peername, _ := range ips {
if peername == hostname {
continue
}
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname { command := []string{
command := []string{ "tailscale", "ping",
"tailscale", "ping", "--timeout=10s",
"--timeout=10s", "--c=20",
"--c=20", "--until-direct=true",
"--until-direct=true", fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
}
fmt.Printf(
"Pinging using Hostname (magicdns) from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result)
assert.Contains(t, result, "pong")
} }
fmt.Printf(
"Pinging using Hostname from %s to %s\n",
hostname,
peername,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result)
assert.Contains(t, result, "pong")
}) })
} }
} }
} }
} }
func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) { // TODO:
ips := make(map[string]netaddr.IP) // * With manual testing, MagicDNS does not respond to AAAA queries. Why?
// * Tailscaled only adds a route to the IPv4 (100.100.100.100) address of the MagicDNS service,
// event though there is an IPv6 one (fd7a:115c:a1e0::53) as well.
func (s *IntegrationTestSuite) TestMagicDNSv4() {
for namespace, scales := range s.namespaces {
ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err)
for hostname, tailscale := range scales.tailscales {
for peername, ips := range ips {
if peername == hostname {
continue
}
s.T().Run(fmt.Sprintf("%s-%s-ipv4", hostname, peername), func(t *testing.T) {
command := []string{
"host", "-4", "-t", "A",
fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
"100.100.100.100",
}
fmt.Printf(
"Resolving name %s (IPv4) from %s over IPv4\n",
peername,
hostname,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result)
resolved := false
for _, ip := range ips {
if strings.Contains(result, fmt.Sprintf("has address %s", ip.String())) {
resolved = true
break
}
}
assert.Equal(t, true, resolved)
})
}
}
}
}
func getIPs(tailscales map[string]dockertest.Resource) (map[string][]netaddr.IP, error) {
ips := make(map[string][]netaddr.IP)
for hostname, tailscale := range tailscales { for hostname, tailscale := range tailscales {
command := []string{"tailscale", "ip"} command := []string{"tailscale", "ip"}
@ -768,12 +820,17 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
return nil, err return nil, err
} }
ip, err := netaddr.ParseIP(strings.TrimSuffix(result, "\n")) for _, address := range strings.Split(result, "\n") {
if err != nil { address = strings.TrimSuffix(address, "\n")
return nil, err if len(address) < 1 {
continue
}
ip, err := netaddr.ParseIP(address)
if err != nil {
return nil, err
}
ips[hostname] = append(ips[hostname], ip)
} }
ips[hostname] = ip
} }
return ips, nil return ips, nil

View file

@ -1,6 +1,7 @@
package headscale package headscale
import ( import (
"database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -23,6 +24,7 @@ const (
errMachineNotFound = Error("machine not found") errMachineNotFound = Error("machine not found")
errMachineAlreadyRegistered = Error("machine already registered") errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineRouteIsNotAvailable = Error("route is not available on machine")
errMachineAddressesInvalid = Error("failed to parse machine addresses")
) )
// Machine is a Headscale client. // Machine is a Headscale client.
@ -31,7 +33,7 @@ type Machine struct {
MachineKey string `gorm:"type:varchar(64);unique_index"` MachineKey string `gorm:"type:varchar(64);unique_index"`
NodeKey string NodeKey string
DiscoKey string DiscoKey string
IPAddress string IPAddresses MachineAddresses
Name string Name string
NamespaceID uint NamespaceID uint
Namespace Namespace `gorm:"foreignKey:NamespaceID"` Namespace Namespace `gorm:"foreignKey:NamespaceID"`
@ -64,6 +66,47 @@ func (machine Machine) isRegistered() bool {
return machine.Registered return machine.Registered
} }
type MachineAddresses []netaddr.IP
func (ma MachineAddresses) ToStringSlice() []string {
strSlice := make([]string, 0, len(ma))
for _, addr := range ma {
strSlice = append(strSlice, addr.String())
}
return strSlice
}
func (ma *MachineAddresses) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
addresses := strings.Split(value, ",")
*ma = (*ma)[:0]
for _, addr := range addresses {
if len(addr) < 1 {
continue
}
parsed, err := netaddr.ParseIP(addr)
if err != nil {
return err
}
*ma = append(*ma, parsed)
}
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (ma MachineAddresses) Value() (driver.Value, error) {
addresses := strings.Join(ma.ToStringSlice(), ",")
return addresses, nil
}
// isExpired returns whether the machine registration has expired. // isExpired returns whether the machine registration has expired.
func (machine Machine) isExpired() bool { func (machine Machine) isExpired() bool {
// If Expiry is not set, the client has not indicated that // If Expiry is not set, the client has not indicated that
@ -470,22 +513,12 @@ func (machine Machine) toNode(
} }
addrs := []netaddr.IPPrefix{} addrs := []netaddr.IPPrefix{}
nodeAddr, err := netaddr.ParseIP(m.IPAddresses) for _, machineAddress := range machine.IPAddresses {
if err != nil { ip := netaddr.IPPrefixFrom(machineAddress, machineAddress.BitLen())
log.Trace(). addrs = append(addrs, ip)
Caller().
Str("ip", machine.IPAddresses).
Msgf("Failed to parse machine IP: %s", machine.IPAddresses)
return nil, err
} }
ip := netaddr.IPPrefixFrom(nodeAddr, nodeAddr.BitLen())
addrs = append(addrs, ip)
allowedIPs := []netaddr.IPPrefix{} allowedIPs := append([]netaddr.IPPrefix{}, addrs...) // we append the node own IP, as it is required by the clients
allowedIPs = append(
allowedIPs,
ip,
) // we append the node own IP, as it is required by the clients
if includeRoutes { if includeRoutes {
routesStr := []string{} routesStr := []string{}
@ -592,11 +625,11 @@ func (machine *Machine) toProto() *v1.Machine {
Id: machine.ID, Id: machine.ID,
MachineKey: machine.MachineKey, MachineKey: machine.MachineKey,
NodeKey: machine.NodeKey, NodeKey: machine.NodeKey,
DiscoKey: machine.DiscoKey, DiscoKey: machine.DiscoKey,
IpAddress: machine.IPAddress, IpAddresses: machine.IPAddresses.ToStringSlice(),
Name: machine.Name, Name: machine.Name,
Namespace: machine.Namespace.toProto(), Namespace: machine.Namespace.toProto(),
Registered: machine.Registered, Registered: machine.Registered,
@ -695,7 +728,7 @@ func (h *Headscale) RegisterMachine(
return nil, err return nil, err
} }
ip, err := h.getAvailableIP() ips, err := h.getAvailableIPs()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -709,10 +742,10 @@ func (h *Headscale) RegisterMachine(
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Found IP for host") Msg("Found IP for host")
machine.IPAddress = ip.String() machine.IPAddresses = ips
machine.NamespaceID = namespace.ID machine.NamespaceID = namespace.ID
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodCLI machine.RegisterMethod = RegisterMethodCLI
@ -722,7 +755,7 @@ func (h *Headscale) RegisterMachine(
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Machine registered with the database") Msg("Machine registered with the database")
return machine, nil return machine, nil

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr"
) )
func (s *Suite) TestGetMachine(c *check.C) { func (s *Suite) TestGetMachine(c *check.C) {
@ -199,3 +200,22 @@ func (s *Suite) TestExpireMachine(c *check.C) {
c.Assert(machineFromDB.isExpired(), check.Equals, true) c.Assert(machineFromDB.isExpired(), check.Equals, true)
} }
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
input := MachineAddresses([]netaddr.IP{
netaddr.MustParseIP("192.0.2.1"),
netaddr.MustParseIP("2001:db8::1"),
})
serialized, err := input.Value()
c.Assert(err, check.IsNil)
c.Assert(serialized.(string), check.Equals, "192.0.2.1,2001:db8::1")
var deserialized MachineAddresses
err = deserialized.Scan(serialized)
c.Assert(err, check.IsNil)
c.Assert(len(deserialized), check.Equals, len(input))
for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i])
}
}

View file

@ -4,6 +4,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr"
) )
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) { func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
@ -146,7 +147,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.1", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID), AuthKeyID: uint(preAuthKeyShared1.ID),
} }
app.db.Save(machineInShared1) app.db.Save(machineInShared1)
@ -164,7 +165,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.2", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID), AuthKeyID: uint(preAuthKeyShared2.ID),
} }
app.db.Save(machineInShared2) app.db.Save(machineInShared2)
@ -182,7 +183,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.3", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID), AuthKeyID: uint(preAuthKeyShared3.ID),
} }
app.db.Save(machineInShared3) app.db.Save(machineInShared3)
@ -200,7 +201,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.4", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID), AuthKeyID: uint(preAuthKey2Shared1.ID),
} }
app.db.Save(machine2InShared1) app.db.Save(machine2InShared1)

View file

@ -126,6 +126,7 @@ var oidcCallbackTemplate = template.Must(
</html>`), </html>`),
) )
// TODO: Why is the entire machine registration logic duplicated here?
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace // Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
@ -316,7 +317,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
ip, err := h.getAvailableIP() ips, err := h.getAvailableIPs()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -330,7 +331,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machine.IPAddress = ip.String() machine.IPAddresses = ips
machine.NamespaceID = namespace.ID machine.NamespaceID = namespace.ID
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodOIDC machine.RegisterMethod = RegisterMethodOIDC

View file

@ -2,6 +2,7 @@ package headscale
import ( import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr"
) )
func CreateNodeNamespace( func CreateNodeNamespace(
@ -26,7 +27,7 @@ func CreateNodeNamespace(
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: ip, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(pak1.ID), AuthKeyID: uint(pak1.ID),
} }
app.db.Save(machine) app.db.Save(machine)
@ -214,7 +215,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
NamespaceID: namespace1.ID, NamespaceID: namespace1.ID,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.4", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(pak4.ID), AuthKeyID: uint(pak4.ID),
} }
app.db.Save(machine4) app.db.Save(machine4)
@ -294,7 +295,7 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
NamespaceID: namespace1.ID, NamespaceID: namespace1.ID,
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddress: "100.64.0.4", IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(pak4n1.ID), AuthKeyID: uint(pak4n1.ID),
} }
app.db.Save(machine4) app.db.Save(machine4)

View file

@ -133,9 +133,24 @@ func encode(
return privKey.SealTo(*pubKey, b), nil return privKey.SealTo(*pubKey, b), nil
} }
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { func (h *Headscale) getAvailableIPs() (ips MachineAddresses, err error) {
ipPrefix := h.cfg.IPPrefix ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes {
var ip *netaddr.IP
ip, err = h.getAvailableIP(ipPrefix)
if err != nil {
return
}
ips = append(ips, *ip)
}
return
}
// TODO: Is this concurrency safe?
// What would happen if multiple hosts were to register at the same time?
// Would we attempt to assign the same addresses to multiple nodes?
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
usedIps, err := h.getUsedIPs() usedIps, err := h.getUsedIPs()
if err != nil { if err != nil {
return nil, err return nil, err
@ -143,6 +158,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := func() (netaddr.IP, netaddr.IP) { ipPrefixNetworkAddress, ipPrefixBroadcastAddress := func() (netaddr.IP, netaddr.IP) {
ipRange := ipPrefix.Range() ipRange := ipPrefix.Range()
return ipRange.From(), ipRange.To() return ipRange.From(), ipRange.To()
}() }()
@ -171,19 +187,20 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
} }
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string // FIXME: This really deserves a better data model,
h.db.Model(&Machine{}).Pluck("ip_address", &addresses) // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
ips := make([]netaddr.IP, len(addresses)) ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
for index, addr := range addresses { for _, slice := range addressesSlices {
if addr != "" { var a MachineAddresses
ip, err := netaddr.ParseIP(addr) err := a.Scan(slice)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse ip from database: %w", err) return nil, fmt.Errorf("failed to read ip from database: %w", err)
}
ips[index] = ip
} }
ips = append(ips, a...)
} }
return ips, nil return ips, nil

View file

@ -6,17 +6,18 @@ import (
) )
func (s *Suite) TestGetAvailableIp(c *check.C) { func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := app.getAvailableIP() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String()) c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
} }
func (s *Suite) TestGetUsedIps(c *check.C) { func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := app.getAvailableIP() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
namespace, err := app.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test_ip")
@ -38,22 +39,24 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddress: ip.String(), IPAddresses: ips,
} }
app.db.Save(&machine) app.db.Save(&machine)
ips, err := app.getUsedIPs() usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ips[0], check.Equals, expected) c.Assert(len(usedIps), check.Equals, 1)
c.Assert(usedIps[0], check.Equals, expected)
machine1, err := app.GetMachineByID(0) machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(machine1.IPAddress, check.Equals, expected.String()) c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
} }
func (s *Suite) TestGetMultiIp(c *check.C) { func (s *Suite) TestGetMultiIp(c *check.C) {
@ -61,7 +64,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
for index := 1; index <= 350; index++ { for index := 1; index <= 350; index++ {
ip, err := app.getAvailableIP() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
@ -80,59 +83,64 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
Registered: true, Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddress: ip.String(), IPAddresses: ips,
} }
app.db.Save(&machine) app.db.Save(&machine)
} }
ips, err := app.getUsedIPs() usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(ips), check.Equals, 350) c.Assert(len(usedIps), check.Equals, 350)
c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45"))
// Check that we can read back the IPs // Check that we can read back the IPs
machine1, err := app.GetMachineByID(1) machine1, err := app.GetMachineByID(1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert( c.Assert(
machine1.IPAddress, machine1.IPAddresses[0],
check.Equals, check.Equals,
netaddr.MustParseIP("10.27.0.1").String(), netaddr.MustParseIP("10.27.0.1"),
) )
machine50, err := app.GetMachineByID(50) machine50, err := app.GetMachineByID(50)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
c.Assert( c.Assert(
machine50.IPAddress, machine50.IPAddresses[0],
check.Equals, check.Equals,
netaddr.MustParseIP("10.27.0.50").String(), netaddr.MustParseIP("10.27.0.50"),
) )
expectedNextIP := netaddr.MustParseIP("10.27.1.95") expectedNextIP := netaddr.MustParseIP("10.27.1.95")
nextIP, err := app.getAvailableIP() nextIP, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) c.Assert(len(nextIP), check.Equals, 1)
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive // If we call get Available again, we should receive
// the same IP, as it has not been reserved. // the same IP, as it has not been reserved.
nextIP2, err := app.getAvailableIP() nextIP2, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
} }
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := app.getAvailableIP() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1") expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String()) c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
namespace, err := app.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -156,8 +164,9 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
} }
app.db.Save(&machine) app.db.Save(&machine)
ip2, err := app.getAvailableIP() ips2, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String()) c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String())
} }