mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
use gorm serialiser instead of custom hooks (#2156)
* add sqlite to debug/test image Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * test using gorm serialiser instead of custom hooks Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
3964dec1c6
commit
bc9e83b52e
21 changed files with 240 additions and 351 deletions
|
@ -8,7 +8,7 @@ ENV GOPATH /go
|
||||||
WORKDIR /go/src/headscale
|
WORKDIR /go/src/headscale
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install --no-install-recommends --yes less jq \
|
&& apt-get install --no-install-recommends --yes less jq sqlite3 \
|
||||||
&& rm -rf /var/lib/apt/lists/* \
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
&& apt-get clean
|
&& apt-get clean
|
||||||
RUN mkdir -p /var/run/headscale
|
RUN mkdir -p /var/run/headscale
|
||||||
|
|
|
@ -20,9 +20,14 @@ import (
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
"tailscale.com/util/set"
|
"tailscale.com/util/set"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
schema.RegisterSerializer("text", TextSerialiser{})
|
||||||
|
}
|
||||||
|
|
||||||
var errDatabaseNotSupported = errors.New("database type not supported")
|
var errDatabaseNotSupported = errors.New("database type not supported")
|
||||||
|
|
||||||
// KV is a key-value store in a psql table. For future use...
|
// KV is a key-value store in a psql table. For future use...
|
||||||
|
@ -34,6 +39,7 @@ type KV struct {
|
||||||
|
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
|
cfg *types.DatabaseConfig
|
||||||
|
|
||||||
baseDomain string
|
baseDomain string
|
||||||
}
|
}
|
||||||
|
@ -191,7 +197,7 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
type NodeAux struct {
|
type NodeAux struct {
|
||||||
ID uint64
|
ID uint64
|
||||||
EnabledRoutes types.IPPrefixes
|
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
nodesAux := []NodeAux{}
|
nodesAux := []NodeAux{}
|
||||||
|
@ -214,7 +220,7 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Preload("Node").
|
err = tx.Preload("Node").
|
||||||
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
|
Where("node_id = ? AND prefix = ?", node.ID, prefix).
|
||||||
First(&types.Route{}).
|
First(&types.Route{}).
|
||||||
Error
|
Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -229,7 +235,7 @@ func NewHeadscaleDatabase(
|
||||||
NodeID: node.ID,
|
NodeID: node.ID,
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Prefix: types.IPPrefix(prefix),
|
Prefix: prefix,
|
||||||
}
|
}
|
||||||
if err := tx.Create(&route).Error; err != nil {
|
if err := tx.Create(&route).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error creating route")
|
log.Error().Err(err).Msg("Error creating route")
|
||||||
|
@ -477,6 +483,7 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
db := HSDatabase{
|
db := HSDatabase{
|
||||||
DB: dbConn,
|
DB: dbConn,
|
||||||
|
cfg: &cfg,
|
||||||
|
|
||||||
baseDomain: baseDomain,
|
baseDomain: baseDomain,
|
||||||
}
|
}
|
||||||
|
@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
|
||||||
|
db.Exec("VACUUM")
|
||||||
|
}
|
||||||
|
|
||||||
return db.Close()
|
return db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,14 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrations(t *testing.T) {
|
func TestMigrations(t *testing.T) {
|
||||||
ipp := func(p string) types.IPPrefix {
|
ipp := func(p string) netip.Prefix {
|
||||||
return types.IPPrefix(netip.MustParsePrefix(p))
|
return netip.MustParsePrefix(p)
|
||||||
}
|
}
|
||||||
r := func(id uint64, p string, a, e, i bool) types.Route {
|
r := func(id uint64, p string, a, e, i bool) types.Route {
|
||||||
return types.Route{
|
return types.Route{
|
||||||
|
@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
|
||||||
r(31, "::/0", true, false, false),
|
r(31, "::/0", true, false, false),
|
||||||
r(32, "192.168.0.24/32", true, true, true),
|
r(32, "192.168.0.24/32", true, true, true),
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
|
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||||
return x == y
|
|
||||||
})); diff != "" {
|
|
||||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
|
||||||
r(13, "::/0", true, true, false),
|
r(13, "::/0", true, true, false),
|
||||||
r(13, "10.18.80.2/32", true, true, true),
|
r(13, "10.18.80.2/32", true, true, true),
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
|
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||||
return x == y
|
|
||||||
})); diff != "" {
|
|
||||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
|
||||||
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
|
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
|
return ListNodes(rx)
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
|
||||||
|
assert.Contains(t, node.MachineKey.String(), "mkey:")
|
||||||
|
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
|
||||||
|
assert.Contains(t, node.NodeKey.String(), "nodekey:")
|
||||||
|
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
|
||||||
|
assert.Contains(t, node.DiscoKey.String(), "discokey:")
|
||||||
|
assert.NotNil(t, node.IPv4)
|
||||||
|
assert.NotNil(t, node.IPv4)
|
||||||
|
assert.Len(t, node.Endpoints, 1)
|
||||||
|
assert.NotNil(t, node.Hostinfo)
|
||||||
|
assert.NotNil(t, node.MachineKey)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||||
return &types.Node{
|
return &types.Node{
|
||||||
IPv4DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: v4,
|
|
||||||
},
|
|
||||||
IPv4: nap(v4),
|
IPv4: nap(v4),
|
||||||
IPv6DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: v6,
|
|
||||||
},
|
|
||||||
IPv6: nap(v6),
|
IPv6: nap(v6),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
|
|
||||||
want: types.Nodes{
|
want: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "100.64.0.1",
|
|
||||||
},
|
|
||||||
IPv4: nap("100.64.0.1"),
|
IPv4: nap("100.64.0.1"),
|
||||||
IPv6DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "fd7a:115c:a1e0::1",
|
|
||||||
},
|
|
||||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
|
|
||||||
want: types.Nodes{
|
want: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "100.64.0.1",
|
|
||||||
},
|
|
||||||
IPv4: nap("100.64.0.1"),
|
IPv4: nap("100.64.0.1"),
|
||||||
IPv6DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "fd7a:115c:a1e0::1",
|
|
||||||
},
|
|
||||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
|
|
||||||
want: types.Nodes{
|
want: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "100.64.0.1",
|
|
||||||
},
|
|
||||||
IPv4: nap("100.64.0.1"),
|
IPv4: nap("100.64.0.1"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
|
|
||||||
want: types.Nodes{
|
want: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv6DatabaseField: sql.NullString{
|
|
||||||
Valid: true,
|
|
||||||
String: "fd7a:115c:a1e0::1",
|
|
||||||
},
|
|
||||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
|
|
||||||
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
|
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
|
||||||
"ID",
|
"ID",
|
||||||
"MachineKeyDatabaseField",
|
|
||||||
"NodeKeyDatabaseField",
|
|
||||||
"DiscoKeyDatabaseField",
|
|
||||||
"User",
|
"User",
|
||||||
"UserID",
|
"UserID",
|
||||||
"Endpoints",
|
"Endpoints",
|
||||||
"HostinfoDatabaseField",
|
|
||||||
"Hostinfo",
|
"Hostinfo",
|
||||||
"Routes",
|
"Routes",
|
||||||
"CreatedAt",
|
"CreatedAt",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -207,21 +208,26 @@ func SetTags(
|
||||||
) error {
|
) error {
|
||||||
if len(tags) == 0 {
|
if len(tags) == 0 {
|
||||||
// if no tags are provided, we remove all forced tags
|
// if no tags are provided, we remove all forced tags
|
||||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
|
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
|
||||||
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
|
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var newTags types.StringList
|
var newTags []string
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if !slices.Contains(newTags, tag) {
|
if !slices.Contains(newTags, tag) {
|
||||||
newTags = append(newTags, tag)
|
newTags = append(newTags, tag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
|
b, err := json.Marshal(newTags)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update tags for node in the database: %w", err)
|
return fmt.Errorf("failed to update tags for node in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB,
|
||||||
for _, prefix := range newRoutes {
|
for _, prefix := range newRoutes {
|
||||||
route := types.Route{}
|
route := types.Route{}
|
||||||
err := tx.Preload("Node").
|
err := tx.Preload("Node").
|
||||||
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
|
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
|
||||||
First(&route).Error
|
First(&route).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
route.Enabled = true
|
route.Enabled = true
|
||||||
|
|
|
@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
|
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: types.NodeID(index),
|
ID: types.NodeID(index),
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
|
@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
|
|
||||||
adminNode, err := db.GetNodeByID(1)
|
adminNode, err := db.GetNodeByID(1)
|
||||||
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
|
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
|
||||||
|
c.Assert(adminNode.IPv4, check.NotNil)
|
||||||
|
c.Assert(adminNode.IPv6, check.IsNil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testNode, err := db.GetNodeByID(2)
|
testNode, err := db.GetNodeByID(2)
|
||||||
|
@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
|
|
||||||
adminPeers, err := db.ListPeers(adminNode.ID)
|
adminPeers, err := db.ListPeers(adminNode.ID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(adminPeers), check.Equals, 9)
|
||||||
|
|
||||||
testPeers, err := db.ListPeers(testNode.ID)
|
testPeers, err := db.ListPeers(testNode.ID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(testPeers), check.Equals, 9)
|
||||||
|
|
||||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
|
|
||||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||||
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
|
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
|
||||||
|
c.Log(peersOfAdminNode)
|
||||||
c.Log(peersOfTestNode)
|
c.Log(peersOfTestNode)
|
||||||
|
|
||||||
c.Assert(len(peersOfTestNode), check.Equals, 9)
|
c.Assert(len(peersOfTestNode), check.Equals, 9)
|
||||||
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
|
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
|
||||||
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
|
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
|
||||||
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
|
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
|
||||||
|
|
||||||
c.Log(peersOfAdminNode)
|
|
||||||
c.Assert(len(peersOfAdminNode), check.Equals, 9)
|
c.Assert(len(peersOfAdminNode), check.Equals, 9)
|
||||||
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
|
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
|
||||||
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
|
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
|
||||||
|
@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
||||||
|
|
||||||
// assign duplicate tags, expect no errors but no doubles in DB
|
// assign duplicate tags, expect no errors but no doubles in DB
|
||||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
|
@ -357,7 +361,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
c.Assert(
|
c.Assert(
|
||||||
node.ForcedTags,
|
node.ForcedTags,
|
||||||
check.DeepEquals,
|
check.DeepEquals,
|
||||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
[]string{"tag:bar", "tag:test", "tag:unknown"},
|
||||||
)
|
)
|
||||||
|
|
||||||
// test removing tags
|
// test removing tags
|
||||||
|
@ -365,7 +369,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
|
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
|
|
|
@ -77,7 +77,7 @@ func CreatePreAuthKey(
|
||||||
Ephemeral: ephemeral,
|
Ephemeral: ephemeral,
|
||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
Tags: types.StringList(aclTags),
|
Tags: aclTags,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Save(&key).Error; err != nil {
|
if err := tx.Save(&key).Error; err != nil {
|
||||||
|
|
|
@ -49,7 +49,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
|
||||||
err := tx.
|
err := tx.
|
||||||
Preload("Node").
|
Preload("Node").
|
||||||
Preload("Node.User").
|
Preload("Node.User").
|
||||||
Where("prefix = ?", types.IPPrefix(pref)).
|
Where("prefix = ?", pref.String()).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -286,7 +286,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
|
||||||
var count int64
|
var count int64
|
||||||
tx.Model(&types.Route{}).
|
tx.Model(&types.Route{}).
|
||||||
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
||||||
route.Prefix,
|
route.Prefix.String(),
|
||||||
route.NodeID,
|
route.NodeID,
|
||||||
true, true).Count(&count)
|
true, true).Count(&count)
|
||||||
|
|
||||||
|
@ -297,7 +297,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
|
||||||
var route types.Route
|
var route types.Route
|
||||||
err := tx.
|
err := tx.
|
||||||
Preload("Node").
|
Preload("Node").
|
||||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
|
||||||
First(&route).Error
|
First(&route).Error
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -392,7 +392,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||||
if !exists {
|
if !exists {
|
||||||
route := types.Route{
|
route := types.Route{
|
||||||
NodeID: node.ID.Uint64(),
|
NodeID: node.ID.Uint64(),
|
||||||
Prefix: types.IPPrefix(prefix),
|
Prefix: prefix,
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
}
|
}
|
||||||
|
|
|
@ -290,7 +290,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
|
||||||
mkNode = func(nid types.NodeID) types.Node {
|
mkNode = func(nid types.NodeID) types.Node {
|
||||||
return types.Node{ID: nid}
|
return types.Node{ID: nid}
|
||||||
}
|
}
|
||||||
|
@ -301,7 +301,7 @@ var np = func(nid types.NodeID) *types.Node {
|
||||||
return &no
|
return &no
|
||||||
}
|
}
|
||||||
|
|
||||||
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
|
||||||
return types.Route{
|
return types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: id,
|
ID: id,
|
||||||
|
@ -313,7 +313,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
|
||||||
ro := r(id, nid, prefix, enabled, primary)
|
ro := r(id, nid, prefix, enabled, primary)
|
||||||
return &ro
|
return &ro
|
||||||
}
|
}
|
||||||
|
@ -1069,7 +1069,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFailoverRoute(t *testing.T) {
|
func TestFailoverRoute(t *testing.T) {
|
||||||
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
|
||||||
return types.Route{
|
return types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: id,
|
ID: id,
|
||||||
|
@ -1082,7 +1082,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
IsPrimary: primary,
|
IsPrimary: primary,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
|
||||||
ro := r(id, nid, prefix, enabled, primary)
|
ro := r(id, nid, prefix, enabled, primary)
|
||||||
return &ro
|
return &ro
|
||||||
}
|
}
|
||||||
|
@ -1205,13 +1205,6 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cmps := append(
|
|
||||||
util.Comparers,
|
|
||||||
cmp.Comparer(func(x, y types.IPPrefix) bool {
|
|
||||||
return netip.Prefix(x) == netip.Prefix(y)
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
|
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
|
||||||
|
@ -1235,7 +1228,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
"old": gotf.old,
|
"old": gotf.old,
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(want, got, cmps...); diff != "" {
|
if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
|
||||||
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
|
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
BIN
hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite
vendored
Normal file
BIN
hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite
vendored
Normal file
Binary file not shown.
99
hscontrol/db/text_serialiser.go
Normal file
99
hscontrol/db/text_serialiser.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||||
|
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||||
|
|
||||||
|
func isTextUnmarshaler(rv reflect.Value) bool {
|
||||||
|
return rv.Type().Implements(textUnmarshalerType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeInstantiatePtr(rv reflect.Value) {
|
||||||
|
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||||
|
np := reflect.New(rv.Type().Elem())
|
||||||
|
rv.Set(np)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodingError(name string, err error) error {
|
||||||
|
return fmt.Errorf("error decoding to %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextSerialiser implements the Serialiser interface for fields that
|
||||||
|
// have a type that implements encoding.TextUnmarshaler.
|
||||||
|
type TextSerialiser struct{}
|
||||||
|
|
||||||
|
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||||
|
fieldValue := reflect.New(field.FieldType)
|
||||||
|
|
||||||
|
// If the field is a pointer, we need to dereference it to get the actual type
|
||||||
|
// so we do not end with a second pointer.
|
||||||
|
if fieldValue.Elem().Kind() == reflect.Ptr {
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbValue != nil {
|
||||||
|
var bytes []byte
|
||||||
|
switch v := dbValue.(type) {
|
||||||
|
case []byte:
|
||||||
|
bytes = v
|
||||||
|
case string:
|
||||||
|
bytes = []byte(v)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isTextUnmarshaler(fieldValue) {
|
||||||
|
maybeInstantiatePtr(fieldValue)
|
||||||
|
f := fieldValue.MethodByName("UnmarshalText")
|
||||||
|
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||||
|
ret := f.Call(args)
|
||||||
|
if !ret[0].IsNil() {
|
||||||
|
return decodingError(field.Name, ret[0].Interface().(error))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the underlying field is to a pointer type, we need to
|
||||||
|
// assign the value as a pointer to it.
|
||||||
|
// If it is not a pointer, we need to assign the value to the
|
||||||
|
// field.
|
||||||
|
dstField := field.ReflectValueOf(ctx, dst)
|
||||||
|
if dstField.Kind() == reflect.Ptr {
|
||||||
|
dstField.Set(fieldValue)
|
||||||
|
} else {
|
||||||
|
dstField.Set(fieldValue.Elem())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||||
|
switch v := fieldValue.(type) {
|
||||||
|
case encoding.TextMarshaler:
|
||||||
|
// If the value is nil, we return nil, however, go nil values are not
|
||||||
|
// always comparable, particularly when reflection is involved:
|
||||||
|
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||||
|
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
b, err := v.MarshalText()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||||
|
}
|
||||||
|
}
|
|
@ -196,19 +196,19 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
Hostinfo: &tailcfg.Hostinfo{},
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
Routes: []types.Route{
|
Routes: []types.Route{
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(tsaddr.AllIPv4()),
|
Prefix: tsaddr.AllIPv4(),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
|
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
|
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
|
|
@ -109,19 +109,19 @@ func TestTailNode(t *testing.T) {
|
||||||
Hostinfo: &tailcfg.Hostinfo{},
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
Routes: []types.Route{
|
Routes: []types.Route{
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(tsaddr.AllIPv4()),
|
Prefix: tsaddr.AllIPv4(),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
|
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
|
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
|
|
@ -595,6 +595,11 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
|
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
|
||||||
// that are correctly tagged since they should not be listed as being in the user
|
// that are correctly tagged since they should not be listed as being in the user
|
||||||
// we assume in this function that we only have nodes from 1 user.
|
// we assume in this function that we only have nodes from 1 user.
|
||||||
|
//
|
||||||
|
// TODO(kradalby): It is quite hard to understand what this function is doing,
|
||||||
|
// it seems like it trying to ensure that we dont include nodes that are tagged
|
||||||
|
// when we look up the nodes owned by a user.
|
||||||
|
// This should be refactored to be more clear as part of the Tags work in #1369
|
||||||
func excludeCorrectlyTaggedNodes(
|
func excludeCorrectlyTaggedNodes(
|
||||||
aclPolicy *ACLPolicy,
|
aclPolicy *ACLPolicy,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
|
@ -613,10 +618,7 @@ func excludeCorrectlyTaggedNodes(
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
found := false
|
found := false
|
||||||
|
|
||||||
if node.Hostinfo == nil {
|
if node.Hostinfo != nil {
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range node.Hostinfo.RequestTags {
|
for _, t := range node.Hostinfo.RequestTags {
|
||||||
if slices.Contains(tags, t) {
|
if slices.Contains(tags, t) {
|
||||||
found = true
|
found = true
|
||||||
|
@ -624,6 +626,8 @@ func excludeCorrectlyTaggedNodes(
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(node.ForcedTags) > 0 {
|
if len(node.ForcedTags) > 0 {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
|
@ -981,7 +985,10 @@ func FilterNodesByACL(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname)
|
||||||
|
|
||||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||||
|
log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname)
|
||||||
result = append(result, peer)
|
result = append(result, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2385,7 +2385,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||||
},
|
},
|
||||||
ForcedTags: types.StringList{"tag:access-servers"},
|
ForcedTags: []string{"tag:access-servers"},
|
||||||
},
|
},
|
||||||
peers: types.Nodes{
|
peers: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
|
@ -3182,7 +3182,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
||||||
Routes: types.Routes{
|
Routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
NodeID: 2,
|
NodeID: 2,
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
|
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
|
@ -3215,7 +3215,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
||||||
Routes: types.Routes{
|
Routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
NodeID: 2,
|
NodeID: 2,
|
||||||
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
|
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
|
@ -3225,13 +3225,6 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
|
|
||||||
prefixComparer := cmp.Comparer(func(x, y types.IPPrefix) bool {
|
|
||||||
return x == y
|
|
||||||
})
|
|
||||||
comparers := append([]cmp.Option{}, util.Comparers...)
|
|
||||||
comparers = append(comparers, prefixComparer)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := FilterNodesByACL(
|
got := FilterNodesByACL(
|
||||||
|
@ -3239,7 +3232,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
||||||
tt.args.nodes,
|
tt.args.nodes,
|
||||||
tt.args.rules,
|
tt.args.rules,
|
||||||
)
|
)
|
||||||
if diff := cmp.Diff(tt.want, got, comparers...); diff != "" {
|
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||||
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -448,13 +449,13 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo)
|
sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo)
|
||||||
|
|
||||||
// The node might not set NetInfo if it has not changed and if
|
// The node might not set NetInfo if it has not changed and if
|
||||||
// the full HostInfo object is overrwritten, the information is lost.
|
// the full HostInfo object is overwritten, the information is lost.
|
||||||
// If there is no NetInfo, keep the previous one.
|
// If there is no NetInfo, keep the previous one.
|
||||||
// From 1.66 the client only sends it if changed:
|
// From 1.66 the client only sends it if changed:
|
||||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||||
// TODO(kradalby): evaulate if we need better comparing of hostinfo
|
// TODO(kradalby): evaulate if we need better comparing of hostinfo
|
||||||
// before we take the changes.
|
// before we take the changes.
|
||||||
if m.req.Hostinfo.NetInfo == nil {
|
if m.req.Hostinfo.NetInfo == nil && m.node.Hostinfo != nil {
|
||||||
m.req.Hostinfo.NetInfo = m.node.Hostinfo.NetInfo
|
m.req.Hostinfo.NetInfo = m.node.Hostinfo.NetInfo
|
||||||
}
|
}
|
||||||
m.node.Hostinfo = m.req.Hostinfo
|
m.node.Hostinfo = m.req.Hostinfo
|
||||||
|
@ -661,8 +662,15 @@ func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if old == nil && new != nil {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
// Routes
|
// Routes
|
||||||
oldRoutes := old.RoutableIPs
|
oldRoutes := make([]netip.Prefix, 0)
|
||||||
|
if old != nil {
|
||||||
|
oldRoutes = old.RoutableIPs
|
||||||
|
}
|
||||||
newRoutes := new.RoutableIPs
|
newRoutes := new.RoutableIPs
|
||||||
|
|
||||||
tsaddr.SortPrefixes(oldRoutes)
|
tsaddr.SortPrefixes(oldRoutes)
|
||||||
|
|
|
@ -2,11 +2,7 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -21,74 +17,6 @@ const (
|
||||||
|
|
||||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||||
|
|
||||||
type IPPrefix netip.Prefix
|
|
||||||
|
|
||||||
func (i *IPPrefix) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case string:
|
|
||||||
prefix, err := netip.ParsePrefix(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*i = IPPrefix(prefix)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i IPPrefix) Value() (driver.Value, error) {
|
|
||||||
prefixStr := netip.Prefix(i).String()
|
|
||||||
|
|
||||||
return prefixStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPPrefixes []netip.Prefix
|
|
||||||
|
|
||||||
func (i *IPPrefixes) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(value, i)
|
|
||||||
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(value), i)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i IPPrefixes) Value() (driver.Value, error) {
|
|
||||||
bytes, err := json.Marshal(i)
|
|
||||||
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type StringList []string
|
|
||||||
|
|
||||||
func (i *StringList) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(value, i)
|
|
||||||
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(value), i)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i StringList) Value() (driver.Value, error) {
|
|
||||||
bytes, err := json.Marshal(i)
|
|
||||||
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type StateUpdateType int
|
type StateUpdateType int
|
||||||
|
|
||||||
func (su StateUpdateType) String() string {
|
func (su StateUpdateType) String() string {
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -15,7 +13,6 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/gorm"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
@ -51,54 +48,16 @@ func (id NodeID) String() string {
|
||||||
type Node struct {
|
type Node struct {
|
||||||
ID NodeID `gorm:"primary_key"`
|
ID NodeID `gorm:"primary_key"`
|
||||||
|
|
||||||
// MachineKeyDatabaseField is the string representation of MachineKey
|
MachineKey key.MachinePublic `gorm:"serializer:text"`
|
||||||
// it is _only_ used for reading and writing the key to the
|
NodeKey key.NodePublic `gorm:"serializer:text"`
|
||||||
// database and should not be used.
|
DiscoKey key.DiscoPublic `gorm:"serializer:text"`
|
||||||
// Use MachineKey instead.
|
|
||||||
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
|
|
||||||
MachineKey key.MachinePublic `gorm:"-"`
|
|
||||||
|
|
||||||
// NodeKeyDatabaseField is the string representation of NodeKey
|
Endpoints []netip.AddrPort `gorm:"serializer:json"`
|
||||||
// it is _only_ used for reading and writing the key to the
|
|
||||||
// database and should not be used.
|
|
||||||
// Use NodeKey instead.
|
|
||||||
NodeKeyDatabaseField string `gorm:"column:node_key"`
|
|
||||||
NodeKey key.NodePublic `gorm:"-"`
|
|
||||||
|
|
||||||
// DiscoKeyDatabaseField is the string representation of DiscoKey
|
Hostinfo *tailcfg.Hostinfo `gorm:"serializer:json"`
|
||||||
// it is _only_ used for reading and writing the key to the
|
|
||||||
// database and should not be used.
|
|
||||||
// Use DiscoKey instead.
|
|
||||||
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
|
|
||||||
DiscoKey key.DiscoPublic `gorm:"-"`
|
|
||||||
|
|
||||||
// EndpointsDatabaseField is the string list representation of Endpoints
|
IPv4 *netip.Addr `gorm:"serializer:text"`
|
||||||
// it is _only_ used for reading and writing the key to the
|
IPv6 *netip.Addr `gorm:"serializer:text"`
|
||||||
// database and should not be used.
|
|
||||||
// Use Endpoints instead.
|
|
||||||
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
|
|
||||||
Endpoints []netip.AddrPort `gorm:"-"`
|
|
||||||
|
|
||||||
// EndpointsDatabaseField is the string list representation of Endpoints
|
|
||||||
// it is _only_ used for reading and writing the key to the
|
|
||||||
// database and should not be used.
|
|
||||||
// Use Endpoints instead.
|
|
||||||
HostinfoDatabaseField string `gorm:"column:host_info"`
|
|
||||||
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
|
|
||||||
|
|
||||||
// IPv4DatabaseField is the string representation of v4 address,
|
|
||||||
// it is _only_ used for reading and writing the key to the
|
|
||||||
// database and should not be used.
|
|
||||||
// Use V4 instead.
|
|
||||||
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
|
|
||||||
IPv4 *netip.Addr `gorm:"-"`
|
|
||||||
|
|
||||||
// IPv6DatabaseField is the string representation of v4 address,
|
|
||||||
// it is _only_ used for reading and writing the key to the
|
|
||||||
// database and should not be used.
|
|
||||||
// Use V6 instead.
|
|
||||||
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
|
|
||||||
IPv6 *netip.Addr `gorm:"-"`
|
|
||||||
|
|
||||||
// Hostname represents the name given by the Tailscale
|
// Hostname represents the name given by the Tailscale
|
||||||
// client during registration
|
// client during registration
|
||||||
|
@ -116,7 +75,7 @@ type Node struct {
|
||||||
|
|
||||||
RegisterMethod string
|
RegisterMethod string
|
||||||
|
|
||||||
ForcedTags StringList
|
ForcedTags []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
// TODO(kradalby): This seems like irrelevant information?
|
// TODO(kradalby): This seems like irrelevant information?
|
||||||
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
|
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
|
||||||
|
@ -216,16 +175,20 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
||||||
src := node.IPs()
|
src := node.IPs()
|
||||||
allowedIPs := node2.IPs()
|
allowedIPs := node2.IPs()
|
||||||
|
|
||||||
|
// TODO(kradalby): Regenerate this everytime the filter change, instead of
|
||||||
|
// every time we use it.
|
||||||
|
matchers := make([]matcher.Match, len(filter))
|
||||||
|
for i, rule := range filter {
|
||||||
|
matchers[i] = matcher.MatchFromFilterRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
for _, route := range node2.Routes {
|
for _, route := range node2.Routes {
|
||||||
if route.Enabled {
|
if route.Enabled {
|
||||||
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
|
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range filter {
|
for _, matcher := range matchers {
|
||||||
// TODO(kradalby): Cache or pregen this
|
|
||||||
matcher := matcher.MatchFromFilterRule(rule)
|
|
||||||
|
|
||||||
if !matcher.SrcsContainsIPs(src) {
|
if !matcher.SrcsContainsIPs(src) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -255,109 +218,6 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeforeSave is a hook that ensures that some values that
|
|
||||||
// cannot be directly marshalled into database values are stored
|
|
||||||
// correctly in the database.
|
|
||||||
// This currently means storing the keys as strings.
|
|
||||||
func (node *Node) BeforeSave(tx *gorm.DB) error {
|
|
||||||
node.MachineKeyDatabaseField = node.MachineKey.String()
|
|
||||||
node.NodeKeyDatabaseField = node.NodeKey.String()
|
|
||||||
node.DiscoKeyDatabaseField = node.DiscoKey.String()
|
|
||||||
|
|
||||||
var endpoints StringList
|
|
||||||
for _, addrPort := range node.Endpoints {
|
|
||||||
endpoints = append(endpoints, addrPort.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
node.EndpointsDatabaseField = endpoints
|
|
||||||
|
|
||||||
hi, err := json.Marshal(node.Hostinfo)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
|
|
||||||
}
|
|
||||||
node.HostinfoDatabaseField = string(hi)
|
|
||||||
|
|
||||||
if node.IPv4 != nil {
|
|
||||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true
|
|
||||||
} else {
|
|
||||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
if node.IPv6 != nil {
|
|
||||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true
|
|
||||||
} else {
|
|
||||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AfterFind is a hook that ensures that Node objects fields that
|
|
||||||
// has a different type in the database is unwrapped and populated
|
|
||||||
// correctly.
|
|
||||||
// This currently unmarshals all the keys, stored as strings, into
|
|
||||||
// the proper types.
|
|
||||||
func (node *Node) AfterFind(tx *gorm.DB) error {
|
|
||||||
var machineKey key.MachinePublic
|
|
||||||
if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil {
|
|
||||||
return fmt.Errorf("unmarshalling machine key from db: %w", err)
|
|
||||||
}
|
|
||||||
node.MachineKey = machineKey
|
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
|
||||||
if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil {
|
|
||||||
return fmt.Errorf("unmarshalling node key from db: %w", err)
|
|
||||||
}
|
|
||||||
node.NodeKey = nodeKey
|
|
||||||
|
|
||||||
// DiscoKey might be empty if a node has not sent it to headscale.
|
|
||||||
// This means that this might fail if the disco key is empty.
|
|
||||||
if node.DiscoKeyDatabaseField != "" {
|
|
||||||
var discoKey key.DiscoPublic
|
|
||||||
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
|
|
||||||
return fmt.Errorf("unmarshalling disco key from db: %w", err)
|
|
||||||
}
|
|
||||||
node.DiscoKey = discoKey
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField))
|
|
||||||
for idx, ep := range node.EndpointsDatabaseField {
|
|
||||||
addrPort, err := netip.ParseAddrPort(ep)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parsing endpoint from db: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoints[idx] = addrPort
|
|
||||||
}
|
|
||||||
node.Endpoints = endpoints
|
|
||||||
|
|
||||||
var hi tailcfg.Hostinfo
|
|
||||||
if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil {
|
|
||||||
return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
|
|
||||||
}
|
|
||||||
node.Hostinfo = &hi
|
|
||||||
|
|
||||||
if node.IPv4DatabaseField.Valid {
|
|
||||||
ip, err := netip.ParseAddr(node.IPv4DatabaseField.String)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parsing IPv4 from database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
node.IPv4 = &ip
|
|
||||||
}
|
|
||||||
|
|
||||||
if node.IPv6DatabaseField.Valid {
|
|
||||||
ip, err := netip.ParseAddr(node.IPv6DatabaseField.String)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parsing IPv6 from database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
node.IPv6 = &ip
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *Node) Proto() *v1.Node {
|
func (node *Node) Proto() *v1.Node {
|
||||||
nodeProto := &v1.Node{
|
nodeProto := &v1.Node{
|
||||||
Id: uint64(node.ID),
|
Id: uint64(node.ID),
|
||||||
|
|
|
@ -17,7 +17,7 @@ type Route struct {
|
||||||
Node Node
|
Node Node
|
||||||
|
|
||||||
// TODO(kradalby): change this custom type to netip.Prefix
|
// TODO(kradalby): change this custom type to netip.Prefix
|
||||||
Prefix IPPrefix
|
Prefix netip.Prefix `gorm:"serializer:text"`
|
||||||
|
|
||||||
Advertised bool
|
Advertised bool
|
||||||
Enabled bool
|
Enabled bool
|
||||||
|
@ -31,7 +31,7 @@ func (r *Route) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) IsExitRoute() bool {
|
func (r *Route) IsExitRoute() bool {
|
||||||
return tsaddr.IsExitRoute(netip.Prefix(r.Prefix))
|
return tsaddr.IsExitRoute(r.Prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) IsAnnouncable() bool {
|
func (r *Route) IsAnnouncable() bool {
|
||||||
|
@ -59,8 +59,8 @@ func (rs Routes) Primaries() Routes {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs Routes) PrefixMap() map[IPPrefix][]Route {
|
func (rs Routes) PrefixMap() map[netip.Prefix][]Route {
|
||||||
res := map[IPPrefix][]Route{}
|
res := map[netip.Prefix][]Route{}
|
||||||
|
|
||||||
for _, route := range rs {
|
for _, route := range rs {
|
||||||
if _, ok := res[route.Prefix]; ok {
|
if _, ok := res[route.Prefix]; ok {
|
||||||
|
@ -80,7 +80,7 @@ func (rs Routes) Proto() []*v1.Route {
|
||||||
protoRoute := v1.Route{
|
protoRoute := v1.Route{
|
||||||
Id: uint64(route.ID),
|
Id: uint64(route.ID),
|
||||||
Node: route.Node.Proto(),
|
Node: route.Node.Proto(),
|
||||||
Prefix: netip.Prefix(route.Prefix).String(),
|
Prefix: route.Prefix.String(),
|
||||||
Advertised: route.Advertised,
|
Advertised: route.Advertised,
|
||||||
Enabled: route.Enabled,
|
Enabled: route.Enabled,
|
||||||
IsPrimary: route.IsPrimary,
|
IsPrimary: route.IsPrimary,
|
||||||
|
|
|
@ -10,16 +10,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrefixMap(t *testing.T) {
|
func TestPrefixMap(t *testing.T) {
|
||||||
ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) }
|
ipp := func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
|
||||||
|
|
||||||
// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
|
|
||||||
prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool {
|
|
||||||
return x == y
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
rs Routes
|
rs Routes
|
||||||
want map[IPPrefix][]Route
|
want map[netip.Prefix][]Route
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
rs: Routes{
|
rs: Routes{
|
||||||
|
@ -27,7 +22,7 @@ func TestPrefixMap(t *testing.T) {
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: map[IPPrefix][]Route{
|
want: map[netip.Prefix][]Route{
|
||||||
ipp("10.0.0.0/24"): Routes{
|
ipp("10.0.0.0/24"): Routes{
|
||||||
Route{
|
Route{
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
@ -44,7 +39,7 @@ func TestPrefixMap(t *testing.T) {
|
||||||
Prefix: ipp("10.0.1.0/24"),
|
Prefix: ipp("10.0.1.0/24"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: map[IPPrefix][]Route{
|
want: map[netip.Prefix][]Route{
|
||||||
ipp("10.0.0.0/24"): Routes{
|
ipp("10.0.0.0/24"): Routes{
|
||||||
Route{
|
Route{
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
@ -68,7 +63,7 @@ func TestPrefixMap(t *testing.T) {
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: map[IPPrefix][]Route{
|
want: map[netip.Prefix][]Route{
|
||||||
ipp("10.0.0.0/24"): Routes{
|
ipp("10.0.0.0/24"): Routes{
|
||||||
Route{
|
Route{
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
@ -86,7 +81,7 @@ func TestPrefixMap(t *testing.T) {
|
||||||
for idx, tt := range tests {
|
for idx, tt := range tests {
|
||||||
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
|
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
|
||||||
got := tt.rs.PrefixMap()
|
got := tt.rs.PrefixMap()
|
||||||
if diff := cmp.Diff(tt.want, got, prefixComparer, util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" {
|
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||||
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
|
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -16,6 +16,8 @@ func DefaultConfigEnv() map[string]string {
|
||||||
"HEADSCALE_POLICY_PATH": "",
|
"HEADSCALE_POLICY_PATH": "",
|
||||||
"HEADSCALE_DATABASE_TYPE": "sqlite",
|
"HEADSCALE_DATABASE_TYPE": "sqlite",
|
||||||
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
|
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
|
||||||
|
"HEADSCALE_DATABASE_DEBUG": "1",
|
||||||
|
"HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1",
|
||||||
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
|
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
|
||||||
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
|
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
|
||||||
"HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48",
|
"HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48",
|
||||||
|
|
Loading…
Reference in a new issue