linter fixes

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-24 09:38:33 -05:00
parent 5fbf3f8327
commit f5feff7c22
No known key found for this signature in database
15 changed files with 527 additions and 397 deletions

View file

@ -27,6 +27,7 @@ linters:
- nolintlint - nolintlint
- musttag # causes issues with imported libs - musttag # causes issues with imported libs
- depguard - depguard
- exportloopref
# We should strive to enable these: # We should strive to enable these:
- wrapcheck - wrapcheck

View file

@ -16,6 +16,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
"zgo.at/zcache/v2" "zgo.at/zcache/v2"
) )
@ -44,7 +45,7 @@ func TestMigrations(t *testing.T) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx) return GetRoutes(rx)
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 10) assert.Len(t, routes, 10)
want := types.Routes{ want := types.Routes{
@ -70,7 +71,7 @@ func TestMigrations(t *testing.T) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx) return GetRoutes(rx)
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 4) assert.Len(t, routes, 4)
want := types.Routes{ want := types.Routes{
@ -132,7 +133,7 @@ func TestMigrations(t *testing.T) {
return append(kratest, testkra...), nil return append(kratest, testkra...), nil
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, keys, 5) assert.Len(t, keys, 5)
want := []types.PreAuthKey{ want := []types.PreAuthKey{
@ -177,7 +178,7 @@ func TestMigrations(t *testing.T) {
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx) return ListNodes(rx)
}) })
assert.NoError(t, err) require.NoError(t, err)
for _, node := range nodes { for _, node := range nodes {
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey") assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")

View file

@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
) )
@ -457,7 +458,12 @@ func TestBackfillIPAddresses(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc() db := tt.dbFunc()
alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential) alloc, err := NewIPAllocator(
db,
tt.prefix4,
tt.prefix6,
types.IPAllocationStrategySequential,
)
if err != nil { if err != nil {
t.Fatalf("failed to set up ip alloc: %s", err) t.Fatalf("failed to set up ip alloc: %s", err)
} }
@ -482,24 +488,29 @@ func TestBackfillIPAddresses(t *testing.T) {
} }
func TestIPAllocatorNextNoReservedIPs(t *testing.T) { func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
alloc, err := NewIPAllocator(db, ptr.To(tsaddr.CGNATRange()), ptr.To(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential) alloc, err := NewIPAllocator(
db,
ptr.To(tsaddr.CGNATRange()),
ptr.To(tsaddr.TailscaleULARange()),
types.IPAllocationStrategySequential,
)
if err != nil { if err != nil {
t.Fatalf("failed to set up ip alloc: %s", err) t.Fatalf("failed to set up ip alloc: %s", err)
} }
// Validate that we do not give out 100.100.100.100 // Validate that we do not give out 100.100.100.100
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange())) nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, na("100.100.100.101"), *nextQuad100) assert.Equal(t, na("100.100.100.101"), *nextQuad100)
// Validate that we do not give out fd7a:115c:a1e0::53 // Validate that we do not give out fd7a:115c:a1e0::53
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange())) nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6) assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
// Validate that we do not give out fd7a:115c:a1e0::53 // Validate that we do not give out fd7a:115c:a1e0::53
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange())) nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
t.Logf("chrome: %s", nextChrome.String()) t.Logf("chrome: %s", nextChrome.String())
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, na("100.115.94.0"), *nextChrome) assert.Equal(t, na("100.115.94.0"), *nextChrome)
} }

View file

@ -17,6 +17,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -558,17 +559,17 @@ func TestAutoApproveRoutes(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
adb, err := newTestDB() adb, err := newTestDB()
assert.NoError(t, err) require.NoError(t, err)
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, pol) assert.NotNil(t, pol)
user, err := adb.CreateUser("test") user, err := adb.CreateUser("test")
assert.NoError(t, err) require.NoError(t, err)
pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
@ -590,21 +591,21 @@ func TestAutoApproveRoutes(t *testing.T) {
} }
trx := adb.DB.Save(&node) trx := adb.DB.Save(&node)
assert.NoError(t, trx.Error) require.NoError(t, trx.Error)
sendUpdate, err := adb.SaveNodeRoutes(&node) sendUpdate, err := adb.SaveNodeRoutes(&node)
assert.NoError(t, err) require.NoError(t, err)
assert.False(t, sendUpdate) assert.False(t, sendUpdate)
node0ByID, err := adb.GetNodeByID(0) node0ByID, err := adb.GetNodeByID(0)
assert.NoError(t, err) require.NoError(t, err)
// TODO(kradalby): Check state update // TODO(kradalby): Check state update
err = adb.EnableAutoApprovedRoutes(pol, node0ByID) err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
assert.NoError(t, err) require.NoError(t, err)
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, enabledRoutes, len(tt.want)) assert.Len(t, enabledRoutes, len(tt.want))
tsaddr.SortPrefixes(enabledRoutes) tsaddr.SortPrefixes(enabledRoutes)
@ -697,13 +698,13 @@ func TestListEphemeralNodes(t *testing.T) {
} }
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
assert.NoError(t, err) require.NoError(t, err)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
@ -726,16 +727,16 @@ func TestListEphemeralNodes(t *testing.T) {
} }
err = db.DB.Save(&node).Error err = db.DB.Save(&node).Error
assert.NoError(t, err) require.NoError(t, err)
err = db.DB.Save(&nodeEph).Error err = db.DB.Save(&nodeEph).Error
assert.NoError(t, err) require.NoError(t, err)
nodes, err := db.ListNodes() nodes, err := db.ListNodes()
assert.NoError(t, err) require.NoError(t, err)
ephemeralNodes, err := db.ListEphemeralNodes() ephemeralNodes, err := db.ListEphemeralNodes()
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assert.Len(t, ephemeralNodes, 1) assert.Len(t, ephemeralNodes, 1)
@ -753,10 +754,10 @@ func TestRenameNode(t *testing.T) {
} }
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
assert.NoError(t, err) require.NoError(t, err)
user2, err := db.CreateUser("test2") user2, err := db.CreateUser("test2")
assert.NoError(t, err) require.NoError(t, err)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
@ -777,10 +778,10 @@ func TestRenameNode(t *testing.T) {
} }
err = db.DB.Save(&node).Error err = db.DB.Save(&node).Error
assert.NoError(t, err) require.NoError(t, err)
err = db.DB.Save(&node2).Error err = db.DB.Save(&node2).Error
assert.NoError(t, err) require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error { err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node, nil, nil) _, err := RegisterNode(tx, node, nil, nil)
@ -790,10 +791,10 @@ func TestRenameNode(t *testing.T) {
_, err = RegisterNode(tx, node2, nil, nil) _, err = RegisterNode(tx, node2, nil, nil)
return err return err
}) })
assert.NoError(t, err) require.NoError(t, err)
nodes, err := db.ListNodes() nodes, err := db.ListNodes()
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
@ -815,26 +816,26 @@ func TestRenameNode(t *testing.T) {
err = db.Write(func(tx *gorm.DB) error { err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "newname") return RenameNode(tx, nodes[0].ID, "newname")
}) })
assert.NoError(t, err) require.NoError(t, err)
nodes, err = db.ListNodes() nodes, err = db.ListNodes()
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test") assert.Equal(t, "test", nodes[0].Hostname)
assert.Equal(t, nodes[0].GivenName, "newname") assert.Equal(t, "newname", nodes[0].GivenName)
// Nodes can reuse name that is no longer used // Nodes can reuse name that is no longer used
err = db.Write(func(tx *gorm.DB) error { err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[1].ID, "test") return RenameNode(tx, nodes[1].ID, "test")
}) })
assert.NoError(t, err) require.NoError(t, err)
nodes, err = db.ListNodes() nodes, err = db.ListNodes()
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, nodes, 2) assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test") assert.Equal(t, "test", nodes[0].Hostname)
assert.Equal(t, nodes[0].GivenName, "newname") assert.Equal(t, "newname", nodes[0].GivenName)
assert.Equal(t, nodes[1].GivenName, "test") assert.Equal(t, "test", nodes[1].GivenName)
// Nodes cannot be renamed to used names // Nodes cannot be renamed to used names
err = db.Write(func(tx *gorm.DB) error { err = db.Write(func(tx *gorm.DB) error {

View file

@ -488,7 +488,7 @@ func (a *AuthProviderOIDC) registerNode(
} }
// TODO(kradalby): // TODO(kradalby):
// Rewrite in elem-go // Rewrite in elem-go.
func renderOIDCCallbackTemplate( func renderOIDCCallbackTemplate(
user *types.User, user *types.User,
) (*bytes.Buffer, error) { ) (*bytes.Buffer, error) {

View file

@ -599,7 +599,7 @@ func (pol *ACLPolicy) ExpandAlias(
// TODO(kradalby): It is quite hard to understand what this function is doing, // TODO(kradalby): It is quite hard to understand what this function is doing,
// it seems like it trying to ensure that we dont include nodes that are tagged // it seems like it trying to ensure that we dont include nodes that are tagged
// when we look up the nodes owned by a user. // when we look up the nodes owned by a user.
// This should be refactored to be more clear as part of the Tags work in #1369 // This should be refactored to be more clear as part of the Tags work in #1369.
func excludeCorrectlyTaggedNodes( func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy, aclPolicy *ACLPolicy,
nodes types.Nodes, nodes types.Nodes,

View file

@ -11,7 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -1824,12 +1824,20 @@ func TestTheInternet(t *testing.T) {
for i := range internetPrefs { for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDest[i].IP { if internetPrefs[i].String() != hsExitNodeDest[i].IP {
t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDest[i].IP) t.Errorf(
"prefix from internet set %q != hsExit list %q",
internetPrefs[i].String(),
hsExitNodeDest[i].IP,
)
} }
} }
if len(internetPrefs) != len(hsExitNodeDest) { if len(internetPrefs) != len(hsExitNodeDest) {
t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDest)) t.Fatalf(
"expected same length of prefixes, internet: %d, hsExit: %d",
len(internetPrefs),
len(hsExitNodeDest),
)
} }
} }
@ -2036,7 +2044,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "100.64.0.100/32", IP: "100.64.0.100/32",
@ -2049,7 +2062,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
}, },
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: hsExitNodeDest, DstPorts: hsExitNodeDest,
}, },
}, },
@ -2132,7 +2150,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "100.64.0.100/32", IP: "100.64.0.100/32",
@ -2145,7 +2168,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
}, },
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
@ -2217,7 +2245,10 @@ func TestReduceFilterRules(t *testing.T) {
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: types.User{Name: "user100"},
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/16"),
netip.MustParsePrefix("16.0.0.0/16"),
},
}, },
}, },
peers: types.Nodes{ peers: types.Nodes{
@ -2234,7 +2265,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "100.64.0.100/32", IP: "100.64.0.100/32",
@ -2247,7 +2283,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
}, },
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "8.0.0.0/8", IP: "8.0.0.0/8",
@ -2294,7 +2335,10 @@ func TestReduceFilterRules(t *testing.T) {
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: types.User{Name: "user100"},
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/8"),
netip.MustParsePrefix("16.0.0.0/8"),
},
}, },
}, },
peers: types.Nodes{ peers: types.Nodes{
@ -2311,7 +2355,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "100.64.0.100/32", IP: "100.64.0.100/32",
@ -2324,7 +2373,12 @@ func TestReduceFilterRules(t *testing.T) {
}, },
}, },
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "8.0.0.0/16", IP: "8.0.0.0/16",
@ -3299,7 +3353,11 @@ func TestSSHRules(t *testing.T) {
SSHUsers: map[string]string{ SSHUsers: map[string]string{
"autogroup:nonroot": "=", "autogroup:nonroot": "=",
}, },
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
},
}, },
{ {
SSHUsers: map[string]string{ SSHUsers: map[string]string{
@ -3310,7 +3368,11 @@ func TestSSHRules(t *testing.T) {
Any: true, Any: true,
}, },
}, },
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
},
}, },
{ {
Principals: []*tailcfg.SSHPrincipal{ Principals: []*tailcfg.SSHPrincipal{
@ -3321,7 +3383,11 @@ func TestSSHRules(t *testing.T) {
SSHUsers: map[string]string{ SSHUsers: map[string]string{
"autogroup:nonroot": "=", "autogroup:nonroot": "=",
}, },
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
},
}, },
{ {
SSHUsers: map[string]string{ SSHUsers: map[string]string{
@ -3332,7 +3398,11 @@ func TestSSHRules(t *testing.T) {
Any: true, Any: true,
}, },
}, },
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
},
}, },
}}, }},
}, },
@ -3392,7 +3462,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
assert.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("TestSSHRules() unexpected result (-want +got):\n%s", diff) t.Errorf("TestSSHRules() unexpected result (-want +got):\n%s", diff)
@ -3499,7 +3569,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) require.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
{ {
@ -3550,7 +3620,7 @@ func TestInvalidTagValidUser(t *testing.T) {
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) require.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
{ {
@ -3609,7 +3679,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) require.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
{ {
@ -3679,7 +3749,7 @@ func TestValidTagInvalidUser(t *testing.T) {
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
assert.NoError(t, err) require.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
{ {

View file

@ -13,7 +13,7 @@ func Windows(url string) *elem.Element {
elem.Text("headscale - Windows"), elem.Text("headscale - Windows"),
), ),
elem.Body(attrs.Props{ elem.Body(attrs.Props{
attrs.Style : bodyStyle.ToInline(), attrs.Style: bodyStyle.ToInline(),
}, },
headerOne("headscale: Windows configuration"), headerOne("headscale: Windows configuration"),
elem.P(nil, elem.P(nil,
@ -21,7 +21,8 @@ func Windows(url string) *elem.Element {
elem.A(attrs.Props{ elem.A(attrs.Props{
attrs.Href: "https://tailscale.com/download/windows", attrs.Href: "https://tailscale.com/download/windows",
attrs.Rel: "noreferrer noopener", attrs.Rel: "noreferrer noopener",
attrs.Target: "_blank"}, attrs.Target: "_blank",
},
elem.Text("Tailscale for Windows ")), elem.Text("Tailscale for Windows ")),
elem.Text("and install it."), elem.Text("and install it."),
), ),

View file

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
) )
@ -35,8 +36,17 @@ func TestReadConfig(t *testing.T) {
MagicDNS: true, MagicDNS: true,
BaseDomain: "example.com", BaseDomain: "example.com",
Nameservers: Nameservers{ Nameservers: Nameservers{
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"}, Global: []string{
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}}, "1.1.1.1",
"1.0.0.1",
"2606:4700:4700::1111",
"2606:4700:4700::1001",
"https://dns.nextdns.io/abc123",
},
Split: map[string][]string{
"darp.headscale.net": {"1.1.1.1", "8.8.8.8"},
"foo.bar.com": {"1.1.1.1"},
},
}, },
ExtraRecords: []tailcfg.DNSRecord{ ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
@ -91,8 +101,17 @@ func TestReadConfig(t *testing.T) {
MagicDNS: false, MagicDNS: false,
BaseDomain: "example.com", BaseDomain: "example.com",
Nameservers: Nameservers{ Nameservers: Nameservers{
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"}, Global: []string{
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}}, "1.1.1.1",
"1.0.0.1",
"2606:4700:4700::1111",
"2606:4700:4700::1001",
"https://dns.nextdns.io/abc123",
},
Split: map[string][]string{
"darp.headscale.net": {"1.1.1.1", "8.8.8.8"},
"foo.bar.com": {"1.1.1.1"},
},
}, },
ExtraRecords: []tailcfg.DNSRecord{ ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
@ -186,7 +205,7 @@ func TestReadConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
viper.Reset() viper.Reset()
err := LoadConfig(tt.configPath, true) err := LoadConfig(tt.configPath, true)
assert.NoError(t, err) require.NoError(t, err)
conf, err := tt.setup(t) conf, err := tt.setup(t)
@ -196,7 +215,7 @@ func TestReadConfig(t *testing.T) {
return return
} }
assert.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.want, conf); diff != "" { if diff := cmp.Diff(tt.want, conf); diff != "" {
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
@ -276,10 +295,10 @@ func TestReadConfigFromEnv(t *testing.T) {
viper.Reset() viper.Reset()
err := LoadConfig("testdata/minimal.yaml", true) err := LoadConfig("testdata/minimal.yaml", true)
assert.NoError(t, err) require.NoError(t, err)
conf, err := tt.setup(t) conf, err := tt.setup(t)
assert.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.want, conf); diff != "" { if diff := cmp.Diff(tt.want, conf); diff != "" {
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
@ -310,13 +329,25 @@ noise:
// Check configuration validation errors (1) // Check configuration validation errors (1)
err = LoadConfig(tmpDir, false) err = LoadConfig(tmpDir, false)
assert.NoError(t, err) require.NoError(t, err)
err = validateServerConfig() err = validateServerConfig()
assert.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") assert.Contains(
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are") t,
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://") err.Error(),
"Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both",
)
assert.Contains(
t,
err.Error(),
"Fatal config error: the only supported values for tls_letsencrypt_challenge_type are",
)
assert.Contains(
t,
err.Error(),
"Fatal config error: server_url must start with https:// or http://",
)
// Check configuration validation errors (2) // Check configuration validation errors (2)
configYaml = []byte(`--- configYaml = []byte(`---
@ -331,5 +362,5 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
t.Fatalf("Couldn't write file %s", configFilePath) t.Fatalf("Couldn't write file %s", configFilePath)
} }
err = LoadConfig(tmpDir, false) err = LoadConfig(tmpDir, false)
assert.NoError(t, err) require.NoError(t, err)
} }

View file

@ -4,12 +4,13 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestGenerateRandomStringDNSSafe(t *testing.T) { func TestGenerateRandomStringDNSSafe(t *testing.T) {
for i := 0; i < 100000; i++ { for i := 0; i < 100000; i++ {
str, err := GenerateRandomStringDNSSafe(8) str, err := GenerateRandomStringDNSSafe(8)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, str, 8) assert.Len(t, str, 8)
} }
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var veryLargeDestination = []string{ var veryLargeDestination = []string{
@ -54,7 +55,7 @@ func aclScenario(
) *Scenario { ) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
spec := map[string]int{ spec := map[string]int{
"user1": clientsPerUser, "user1": clientsPerUser,
@ -77,10 +78,10 @@ func aclScenario(
hsic.WithACLPolicy(policy), hsic.WithACLPolicy(policy),
hsic.WithTestName("acl"), hsic.WithTestName("acl"),
) )
assertNoErr(t, err) require.NoError(t, err)
_, err = scenario.ListTailscaleClientsFQDNs() _, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) require.NoError(t, err)
return scenario return scenario
} }
@ -267,7 +268,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
for name, testCase := range tests { for name, testCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
spec := testCase.users spec := testCase.users
@ -275,22 +276,22 @@ func TestACLHostsInNetMapTable(t *testing.T) {
[]tsic.Option{}, []tsic.Option{},
hsic.WithACLPolicy(&testCase.policy), hsic.WithACLPolicy(&testCase.policy),
) )
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErr(t, err) require.NoError(t, err)
err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"]) err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"])
assertNoErrSync(t, err) require.NoError(t, err)
for _, client := range allClients { for _, client := range allClients {
status, err := client.Status() status, err := client.Status()
assertNoErr(t, err) require.NoError(t, err)
user := status.User[status.Self.UserID].LoginName user := status.User[status.Self.UserID].LoginName
assert.Equal(t, (testCase.want[user]), len(status.Peer)) assert.Len(t, status.Peer, (testCase.want[user]))
} }
}) })
} }
@ -319,23 +320,23 @@ func TestACLAllowUser80Dst(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1") user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err) require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2") user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErr(t, err) require.NoError(t, err)
// Test that user1 can visit all user2 // Test that user1 can visit all user2
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
@ -343,14 +344,14 @@ func TestACLAllowUser80Dst(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
} }
} }
} }
@ -376,10 +377,10 @@ func TestACLDenyAllPort80(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErr(t, err) require.NoError(t, err)
allHostnames, err := scenario.ListTailscaleClientsFQDNs() allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErr(t, err) require.NoError(t, err)
for _, client := range allClients { for _, client := range allClients {
for _, hostname := range allHostnames { for _, hostname := range allHostnames {
@ -394,7 +395,7 @@ func TestACLDenyAllPort80(t *testing.T) {
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
} }
} }
} }
@ -420,23 +421,23 @@ func TestACLAllowUserDst(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1") user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err) require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2") user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErr(t, err) require.NoError(t, err)
// Test that user1 can visit all user2 // Test that user1 can visit all user2
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
@ -444,14 +445,14 @@ func TestACLAllowUserDst(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
} }
} }
} }
@ -476,23 +477,23 @@ func TestACLAllowStarDst(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1") user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err) require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2") user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErr(t, err) require.NoError(t, err)
// Test that user1 can visit all user2 // Test that user1 can visit all user2
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
@ -500,14 +501,14 @@ func TestACLAllowStarDst(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
} }
} }
} }
@ -537,23 +538,23 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1") user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err) require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2") user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErr(t, err) require.NoError(t, err)
// Test that user1 can visit all user2 // Test that user1 can visit all user2
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
@ -561,14 +562,14 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
} }
@ -679,10 +680,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test1ip4 := netip.MustParseAddr("100.64.0.1") test1ip4 := netip.MustParseAddr("100.64.0.1")
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
test1, err := scenario.FindTailscaleClientByIP(test1ip6) test1, err := scenario.FindTailscaleClientByIP(test1ip6)
assertNoErr(t, err) require.NoError(t, err)
test1fqdn, err := test1.FQDN() test1fqdn, err := test1.FQDN()
assertNoErr(t, err) require.NoError(t, err)
test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String()) test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String())
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
@ -690,10 +691,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test2ip4 := netip.MustParseAddr("100.64.0.2") test2ip4 := netip.MustParseAddr("100.64.0.2")
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
test2, err := scenario.FindTailscaleClientByIP(test2ip6) test2, err := scenario.FindTailscaleClientByIP(test2ip6)
assertNoErr(t, err) require.NoError(t, err)
test2fqdn, err := test2.FQDN() test2fqdn, err := test2.FQDN()
assertNoErr(t, err) require.NoError(t, err)
test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String()) test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String())
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
@ -701,10 +702,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3ip4 := netip.MustParseAddr("100.64.0.3") test3ip4 := netip.MustParseAddr("100.64.0.3")
test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3") test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3")
test3, err := scenario.FindTailscaleClientByIP(test3ip6) test3, err := scenario.FindTailscaleClientByIP(test3ip6)
assertNoErr(t, err) require.NoError(t, err)
test3fqdn, err := test3.FQDN() test3fqdn, err := test3.FQDN()
assertNoErr(t, err) require.NoError(t, err)
test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String()) test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String())
test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String()) test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String())
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
@ -719,7 +720,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3ip4URL, test3ip4URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test3ip6URL) result, err = test1.Curl(test3ip6URL)
assert.Lenf( assert.Lenf(
@ -730,7 +731,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3ip6URL, test3ip6URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test3fqdnURL) result, err = test1.Curl(test3fqdnURL)
assert.Lenf( assert.Lenf(
@ -741,7 +742,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3fqdnURL, test3fqdnURL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
// test2 can query test3 // test2 can query test3
result, err = test2.Curl(test3ip4URL) result, err = test2.Curl(test3ip4URL)
@ -753,7 +754,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3ip4URL, test3ip4URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test2.Curl(test3ip6URL) result, err = test2.Curl(test3ip6URL)
assert.Lenf( assert.Lenf(
@ -764,7 +765,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3ip6URL, test3ip6URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test2.Curl(test3fqdnURL) result, err = test2.Curl(test3fqdnURL)
assert.Lenf( assert.Lenf(
@ -775,33 +776,33 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3fqdnURL, test3fqdnURL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
// test3 cannot query test1 // test3 cannot query test1
result, err = test3.Curl(test1ip4URL) result, err = test3.Curl(test1ip4URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test3.Curl(test1ip6URL) result, err = test3.Curl(test1ip6URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test3.Curl(test1fqdnURL) result, err = test3.Curl(test1fqdnURL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
// test3 cannot query test2 // test3 cannot query test2
result, err = test3.Curl(test2ip4URL) result, err = test3.Curl(test2ip4URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test3.Curl(test2ip6URL) result, err = test3.Curl(test2ip6URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test3.Curl(test2fqdnURL) result, err = test3.Curl(test2fqdnURL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
// test1 can query test2 // test1 can query test2
result, err = test1.Curl(test2ip4URL) result, err = test1.Curl(test2ip4URL)
@ -814,7 +815,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test2ip6URL) result, err = test1.Curl(test2ip6URL)
assert.Lenf( assert.Lenf(
t, t,
@ -824,7 +825,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test2ip6URL, test2ip6URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test2fqdnURL) result, err = test1.Curl(test2fqdnURL)
assert.Lenf( assert.Lenf(
@ -835,20 +836,20 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test2fqdnURL, test2fqdnURL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
// test2 cannot query test1 // test2 cannot query test1
result, err = test2.Curl(test1ip4URL) result, err = test2.Curl(test1ip4URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test2.Curl(test1ip6URL) result, err = test2.Curl(test1ip6URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test2.Curl(test1fqdnURL) result, err = test2.Curl(test1fqdnURL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
}) })
} }
} }
@ -946,10 +947,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
test1, err := scenario.FindTailscaleClientByIP(test1ip) test1, err := scenario.FindTailscaleClientByIP(test1ip)
assert.NotNil(t, test1) assert.NotNil(t, test1)
assertNoErr(t, err) require.NoError(t, err)
test1fqdn, err := test1.FQDN() test1fqdn, err := test1.FQDN()
assertNoErr(t, err) require.NoError(t, err)
test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String())
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
@ -958,10 +959,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
test2, err := scenario.FindTailscaleClientByIP(test2ip) test2, err := scenario.FindTailscaleClientByIP(test2ip)
assert.NotNil(t, test2) assert.NotNil(t, test2)
assertNoErr(t, err) require.NoError(t, err)
test2fqdn, err := test2.FQDN() test2fqdn, err := test2.FQDN()
assertNoErr(t, err) require.NoError(t, err)
test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String())
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
@ -976,7 +977,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2ipURL, test2ipURL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test2ip6URL) result, err = test1.Curl(test2ip6URL)
assert.Lenf( assert.Lenf(
@ -987,7 +988,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2ip6URL, test2ip6URL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test1.Curl(test2fqdnURL) result, err = test1.Curl(test2fqdnURL)
assert.Lenf( assert.Lenf(
@ -998,19 +999,19 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2fqdnURL, test2fqdnURL,
result, result,
) )
assertNoErr(t, err) require.NoError(t, err)
result, err = test2.Curl(test1ipURL) result, err = test2.Curl(test1ipURL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test2.Curl(test1ip6URL) result, err = test2.Curl(test1ip6URL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
result, err = test2.Curl(test1fqdnURL) result, err = test2.Curl(test1fqdnURL)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
}) })
} }
} }
@ -1020,7 +1021,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1046,19 +1047,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
"HEADSCALE_POLICY_MODE": "database", "HEADSCALE_POLICY_MODE": "database",
}), }),
) )
assertNoErr(t, err) require.NoError(t, err)
_, err = scenario.ListTailscaleClientsFQDNs() _, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) require.NoError(t, err)
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) require.NoError(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1") user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err) require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2") user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErr(t, err) require.NoError(t, err)
all := append(user1Clients, user2Clients...) all := append(user1Clients, user2Clients...)
@ -1070,19 +1071,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
} }
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
p := policy.ACLPolicy{ p := policy.ACLPolicy{
ACLs: []policy.ACL{ ACLs: []policy.ACL{
@ -1100,7 +1101,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
policyFilePath := "/etc/headscale/policy.json" policyFilePath := "/etc/headscale/policy.json"
err = headscale.WriteFile(policyFilePath, pBytes) err = headscale.WriteFile(policyFilePath, pBytes)
assertNoErr(t, err) require.NoError(t, err)
// No policy is present at this time. // No policy is present at this time.
// Add a new policy from a file. // Add a new policy from a file.
@ -1113,7 +1114,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
policyFilePath, policyFilePath,
}, },
) )
assertNoErr(t, err) require.NoError(t, err)
// Get the current policy and check // Get the current policy and check
// if it is the same as the one we set. // if it is the same as the one we set.
@ -1129,7 +1130,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
}, },
&output, &output,
) )
assertNoErr(t, err) require.NoError(t, err)
assert.Len(t, output.ACLs, 1) assert.Len(t, output.ACLs, 1)
@ -1141,14 +1142,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(t, result, 13)
assertNoErr(t, err) require.NoError(t, err)
} }
} }
@ -1156,14 +1157,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
assertNoErr(t, err) require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(t, result)
assert.Error(t, err) require.Error(t, err)
} }
} }
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -34,7 +35,7 @@ func TestUserCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -43,10 +44,10 @@ func TestUserCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
var listUsers []v1.User var listUsers []v1.User
err = executeAndUnmarshal(headscale, err = executeAndUnmarshal(headscale,
@ -59,7 +60,7 @@ func TestUserCommand(t *testing.T) {
}, },
&listUsers, &listUsers,
) )
assertNoErr(t, err) require.NoError(t, err)
result := []string{listUsers[0].GetName(), listUsers[1].GetName()} result := []string{listUsers[0].GetName(), listUsers[1].GetName()}
sort.Strings(result) sort.Strings(result)
@ -81,7 +82,7 @@ func TestUserCommand(t *testing.T) {
"newname", "newname",
}, },
) )
assertNoErr(t, err) require.NoError(t, err)
var listAfterRenameUsers []v1.User var listAfterRenameUsers []v1.User
err = executeAndUnmarshal(headscale, err = executeAndUnmarshal(headscale,
@ -94,7 +95,7 @@ func TestUserCommand(t *testing.T) {
}, },
&listAfterRenameUsers, &listAfterRenameUsers,
) )
assertNoErr(t, err) require.NoError(t, err)
result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()} result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()}
sort.Strings(result) sort.Strings(result)
@ -114,7 +115,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
count := 3 count := 3
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -122,13 +123,13 @@ func TestPreAuthKeyCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
keys := make([]*v1.PreAuthKey, count) keys := make([]*v1.PreAuthKey, count)
assertNoErr(t, err) require.NoError(t, err)
for index := 0; index < count; index++ { for index := 0; index < count; index++ {
var preAuthKey v1.PreAuthKey var preAuthKey v1.PreAuthKey
@ -150,7 +151,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
}, },
&preAuthKey, &preAuthKey,
) )
assertNoErr(t, err) require.NoError(t, err)
keys[index] = &preAuthKey keys[index] = &preAuthKey
} }
@ -171,7 +172,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
}, },
&listedPreAuthKeys, &listedPreAuthKeys,
) )
assertNoErr(t, err) require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv" // There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 4) assert.Len(t, listedPreAuthKeys, 4)
@ -212,7 +213,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
continue continue
} }
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
} }
// Test key expiry // Test key expiry
@ -226,7 +227,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
listedPreAuthKeys[1].GetKey(), listedPreAuthKeys[1].GetKey(),
}, },
) )
assertNoErr(t, err) require.NoError(t, err)
var listedPreAuthKeysAfterExpire []v1.PreAuthKey var listedPreAuthKeysAfterExpire []v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -242,7 +243,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
}, },
&listedPreAuthKeysAfterExpire, &listedPreAuthKeysAfterExpire,
) )
assertNoErr(t, err) require.NoError(t, err)
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
@ -256,7 +257,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
user := "pre-auth-key-without-exp-user" user := "pre-auth-key-without-exp-user"
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -264,10 +265,10 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
var preAuthKey v1.PreAuthKey var preAuthKey v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -284,7 +285,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
}, },
&preAuthKey, &preAuthKey,
) )
assertNoErr(t, err) require.NoError(t, err)
var listedPreAuthKeys []v1.PreAuthKey var listedPreAuthKeys []v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -300,7 +301,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
}, },
&listedPreAuthKeys, &listedPreAuthKeys,
) )
assertNoErr(t, err) require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv" // There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 2) assert.Len(t, listedPreAuthKeys, 2)
@ -319,7 +320,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
user := "pre-auth-key-reus-ephm-user" user := "pre-auth-key-reus-ephm-user"
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -327,10 +328,10 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
var preAuthReusableKey v1.PreAuthKey var preAuthReusableKey v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -347,7 +348,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
}, },
&preAuthReusableKey, &preAuthReusableKey,
) )
assertNoErr(t, err) require.NoError(t, err)
var preAuthEphemeralKey v1.PreAuthKey var preAuthEphemeralKey v1.PreAuthKey
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -364,7 +365,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
}, },
&preAuthEphemeralKey, &preAuthEphemeralKey,
) )
assertNoErr(t, err) require.NoError(t, err)
assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.True(t, preAuthEphemeralKey.GetEphemeral())
assert.False(t, preAuthEphemeralKey.GetReusable()) assert.False(t, preAuthEphemeralKey.GetReusable())
@ -383,7 +384,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
}, },
&listedPreAuthKeys, &listedPreAuthKeys,
) )
assertNoErr(t, err) require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv" // There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 3) assert.Len(t, listedPreAuthKeys, 3)
@ -397,7 +398,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
user2 := "user2" user2 := "user2"
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -413,10 +414,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
) )
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
var user2Key v1.PreAuthKey var user2Key v1.PreAuthKey
@ -438,10 +439,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}, },
&user2Key, &user2Key,
) )
assertNoErr(t, err) require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err) require.NoError(t, err)
assert.Len(t, allClients, 1) assert.Len(t, allClients, 1)
@ -449,22 +450,22 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
// Log out from user1 // Log out from user1
err = client.Logout() err = client.Logout()
assertNoErr(t, err) require.NoError(t, err)
err = scenario.WaitForTailscaleLogout() err = scenario.WaitForTailscaleLogout()
assertNoErr(t, err) require.NoError(t, err)
status, err := client.Status() status, err := client.Status()
assertNoErr(t, err) require.NoError(t, err)
if status.BackendState == "Starting" || status.BackendState == "Running" { if status.BackendState == "Starting" || status.BackendState == "Running" {
t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState) t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
} }
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
assertNoErr(t, err) require.NoError(t, err)
status, err = client.Status() status, err = client.Status()
assertNoErr(t, err) require.NoError(t, err)
if status.BackendState != "Running" { if status.BackendState != "Running" {
t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState) t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
} }
@ -485,7 +486,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}, },
&listNodes, &listNodes,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listNodes, 1) assert.Len(t, listNodes, 1)
assert.Equal(t, "user2", listNodes[0].GetUser().GetName()) assert.Equal(t, "user2", listNodes[0].GetUser().GetName())
@ -498,7 +499,7 @@ func TestApiKeyCommand(t *testing.T) {
count := 5 count := 5
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -507,10 +508,10 @@ func TestApiKeyCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
keys := make([]string, count) keys := make([]string, count)
@ -526,7 +527,7 @@ func TestApiKeyCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
assert.NotEmpty(t, apiResult) assert.NotEmpty(t, apiResult)
keys[idx] = apiResult keys[idx] = apiResult
@ -545,7 +546,7 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAPIKeys, &listedAPIKeys,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listedAPIKeys, 5) assert.Len(t, listedAPIKeys, 5)
@ -601,7 +602,7 @@ func TestApiKeyCommand(t *testing.T) {
listedAPIKeys[idx].GetPrefix(), listedAPIKeys[idx].GetPrefix(),
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
} }
@ -617,7 +618,7 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAfterExpireAPIKeys, &listedAfterExpireAPIKeys,
) )
assert.Nil(t, err) require.NoError(t, err)
for index := range listedAfterExpireAPIKeys { for index := range listedAfterExpireAPIKeys {
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
@ -643,7 +644,7 @@ func TestApiKeyCommand(t *testing.T) {
"--prefix", "--prefix",
listedAPIKeys[0].GetPrefix(), listedAPIKeys[0].GetPrefix(),
}) })
assert.Nil(t, err) require.NoError(t, err)
var listedAPIKeysAfterDelete []v1.ApiKey var listedAPIKeysAfterDelete []v1.ApiKey
err = executeAndUnmarshal(headscale, err = executeAndUnmarshal(headscale,
@ -656,7 +657,7 @@ func TestApiKeyCommand(t *testing.T) {
}, },
&listedAPIKeysAfterDelete, &listedAPIKeysAfterDelete,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listedAPIKeysAfterDelete, 4) assert.Len(t, listedAPIKeysAfterDelete, 4)
} }
@ -666,7 +667,7 @@ func TestNodeTagCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -674,17 +675,17 @@ func TestNodeTagCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
machineKeys := []string{ machineKeys := []string{
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
"mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
} }
nodes := make([]*v1.Node, len(machineKeys)) nodes := make([]*v1.Node, len(machineKeys))
assert.Nil(t, err) require.NoError(t, err)
for index, machineKey := range machineKeys { for index, machineKey := range machineKeys {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -702,7 +703,7 @@ func TestNodeTagCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -720,7 +721,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -739,7 +740,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
@ -753,7 +754,7 @@ func TestNodeTagCommand(t *testing.T) {
"--output", "json", "--output", "json",
}, },
) )
assert.ErrorContains(t, err, "tag must start with the string 'tag:'") require.ErrorContains(t, err, "tag must start with the string 'tag:'")
// Test list all nodes after added seconds // Test list all nodes after added seconds
resultMachines := make([]*v1.Node, len(machineKeys)) resultMachines := make([]*v1.Node, len(machineKeys))
@ -767,7 +768,7 @@ func TestNodeTagCommand(t *testing.T) {
}, },
&resultMachines, &resultMachines,
) )
assert.Nil(t, err) require.NoError(t, err)
found := false found := false
for _, node := range resultMachines { for _, node := range resultMachines {
if node.GetForcedTags() != nil { if node.GetForcedTags() != nil {
@ -778,9 +779,8 @@ func TestNodeTagCommand(t *testing.T) {
} }
} }
} }
assert.Equal( assert.True(
t, t,
true,
found, found,
"should find a node with the tag 'tag:test' in the list of nodes", "should find a node with the tag 'tag:test' in the list of nodes",
) )
@ -791,18 +791,22 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
"user1": 1, "user1": 1,
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags")) err = scenario.CreateHeadscaleEnv(
assertNoErr(t, err) spec,
[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
hsic.WithTestName("cliadvtags"),
)
require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Test list all nodes after added seconds // Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"]) resultMachines := make([]*v1.Node, spec["user1"])
@ -817,7 +821,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
}, },
&resultMachines, &resultMachines,
) )
assert.Nil(t, err) require.NoError(t, err)
found := false found := false
for _, node := range resultMachines { for _, node := range resultMachines {
if node.GetInvalidTags() != nil { if node.GetInvalidTags() != nil {
@ -828,9 +832,8 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
} }
} }
} }
assert.Equal( assert.True(
t, t,
true,
found, found,
"should not find a node with the tag 'tag:test' in the list of nodes", "should not find a node with the tag 'tag:test' in the list of nodes",
) )
@ -841,31 +844,36 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
"user1": 1, "user1": 1,
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy( err = scenario.CreateHeadscaleEnv(
&policy.ACLPolicy{ spec,
ACLs: []policy.ACL{ []tsic.Option{tsic.WithTags([]string{"tag:exists"})},
{ hsic.WithTestName("cliadvtags"),
Action: "accept", hsic.WithACLPolicy(
Sources: []string{"*"}, &policy.ACLPolicy{
Destinations: []string{"*:*"}, ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
TagOwners: map[string][]string{
"tag:exists": {"user1"},
}, },
}, },
TagOwners: map[string][]string{ ),
"tag:exists": {"user1"}, )
}, require.NoError(t, err)
},
))
assertNoErr(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Test list all nodes after added seconds // Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"]) resultMachines := make([]*v1.Node, spec["user1"])
@ -880,7 +888,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
}, },
&resultMachines, &resultMachines,
) )
assert.Nil(t, err) require.NoError(t, err)
found := false found := false
for _, node := range resultMachines { for _, node := range resultMachines {
if node.GetValidTags() != nil { if node.GetValidTags() != nil {
@ -891,9 +899,8 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
} }
} }
} }
assert.Equal( assert.True(
t, t,
true,
found, found,
"should not find a node with the tag 'tag:exists' in the list of nodes", "should not find a node with the tag 'tag:exists' in the list of nodes",
) )
@ -904,7 +911,7 @@ func TestNodeCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -913,10 +920,10 @@ func TestNodeCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Pregenerated machine keys // Pregenerated machine keys
machineKeys := []string{ machineKeys := []string{
@ -927,7 +934,7 @@ func TestNodeCommand(t *testing.T) {
"mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
} }
nodes := make([]*v1.Node, len(machineKeys)) nodes := make([]*v1.Node, len(machineKeys))
assert.Nil(t, err) require.NoError(t, err)
for index, machineKey := range machineKeys { for index, machineKey := range machineKeys {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -945,7 +952,7 @@ func TestNodeCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -963,7 +970,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -983,7 +990,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1004,7 +1011,7 @@ func TestNodeCommand(t *testing.T) {
"mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", "mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584",
} }
otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys)) otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys))
assert.Nil(t, err) require.NoError(t, err)
for index, machineKey := range otherUserMachineKeys { for index, machineKey := range otherUserMachineKeys {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -1022,7 +1029,7 @@ func TestNodeCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1040,7 +1047,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
otherUserMachines[index] = &node otherUserMachines[index] = &node
} }
@ -1060,7 +1067,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listAllWithotherUser, &listAllWithotherUser,
) )
assert.Nil(t, err) require.NoError(t, err)
// All nodes, nodes + otherUser // All nodes, nodes + otherUser
assert.Len(t, listAllWithotherUser, 7) assert.Len(t, listAllWithotherUser, 7)
@ -1086,7 +1093,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listOnlyotherUserMachineUser, &listOnlyotherUserMachineUser,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listOnlyotherUserMachineUser, 2) assert.Len(t, listOnlyotherUserMachineUser, 2)
@ -1118,7 +1125,7 @@ func TestNodeCommand(t *testing.T) {
"--force", "--force",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
// Test: list main user after node is deleted // Test: list main user after node is deleted
var listOnlyMachineUserAfterDelete []v1.Node var listOnlyMachineUserAfterDelete []v1.Node
@ -1135,7 +1142,7 @@ func TestNodeCommand(t *testing.T) {
}, },
&listOnlyMachineUserAfterDelete, &listOnlyMachineUserAfterDelete,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listOnlyMachineUserAfterDelete, 4) assert.Len(t, listOnlyMachineUserAfterDelete, 4)
} }
@ -1145,7 +1152,7 @@ func TestNodeExpireCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1153,10 +1160,10 @@ func TestNodeExpireCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Pregenerated machine keys // Pregenerated machine keys
machineKeys := []string{ machineKeys := []string{
@ -1184,7 +1191,7 @@ func TestNodeExpireCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1202,7 +1209,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -1221,7 +1228,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1241,7 +1248,7 @@ func TestNodeExpireCommand(t *testing.T) {
fmt.Sprintf("%d", listAll[idx].GetId()), fmt.Sprintf("%d", listAll[idx].GetId()),
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
} }
var listAllAfterExpiry []v1.Node var listAllAfterExpiry []v1.Node
@ -1256,7 +1263,7 @@ func TestNodeExpireCommand(t *testing.T) {
}, },
&listAllAfterExpiry, &listAllAfterExpiry,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAllAfterExpiry, 5) assert.Len(t, listAllAfterExpiry, 5)
@ -1272,7 +1279,7 @@ func TestNodeRenameCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1280,10 +1287,10 @@ func TestNodeRenameCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Pregenerated machine keys // Pregenerated machine keys
machineKeys := []string{ machineKeys := []string{
@ -1294,7 +1301,7 @@ func TestNodeRenameCommand(t *testing.T) {
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
} }
nodes := make([]*v1.Node, len(machineKeys)) nodes := make([]*v1.Node, len(machineKeys))
assert.Nil(t, err) require.NoError(t, err)
for index, machineKey := range machineKeys { for index, machineKey := range machineKeys {
_, err := headscale.Execute( _, err := headscale.Execute(
@ -1312,7 +1319,7 @@ func TestNodeRenameCommand(t *testing.T) {
"json", "json",
}, },
) )
assertNoErr(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1330,7 +1337,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&node, &node,
) )
assertNoErr(t, err) require.NoError(t, err)
nodes[index] = &node nodes[index] = &node
} }
@ -1349,7 +1356,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAll, &listAll,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAll, 5) assert.Len(t, listAll, 5)
@ -1370,7 +1377,7 @@ func TestNodeRenameCommand(t *testing.T) {
fmt.Sprintf("newnode-%d", idx+1), fmt.Sprintf("newnode-%d", idx+1),
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Contains(t, res, "Node renamed") assert.Contains(t, res, "Node renamed")
} }
@ -1387,7 +1394,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAllAfterRename, &listAllAfterRename,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAllAfterRename, 5) assert.Len(t, listAllAfterRename, 5)
@ -1408,7 +1415,7 @@ func TestNodeRenameCommand(t *testing.T) {
strings.Repeat("t", 64), strings.Repeat("t", 64),
}, },
) )
assert.ErrorContains(t, err, "not be over 63 chars") require.ErrorContains(t, err, "not be over 63 chars")
var listAllAfterRenameAttempt []v1.Node var listAllAfterRenameAttempt []v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1422,7 +1429,7 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
&listAllAfterRenameAttempt, &listAllAfterRenameAttempt,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, listAllAfterRenameAttempt, 5) assert.Len(t, listAllAfterRenameAttempt, 5)
@ -1438,7 +1445,7 @@ func TestNodeMoveCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1447,10 +1454,10 @@ func TestNodeMoveCommand(t *testing.T) {
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
// Randomly generated node key // Randomly generated node key
machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa"
@ -1470,7 +1477,7 @@ func TestNodeMoveCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
var node v1.Node var node v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1488,11 +1495,11 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, uint64(1), node.GetId())
assert.Equal(t, "nomad-node", node.GetName()) assert.Equal(t, "nomad-node", node.GetName())
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
nodeID := fmt.Sprintf("%d", node.GetId()) nodeID := fmt.Sprintf("%d", node.GetId())
@ -1511,9 +1518,9 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "new-user") assert.Equal(t, "new-user", node.GetUser().GetName())
var allNodes []v1.Node var allNodes []v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1527,13 +1534,13 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&allNodes, &allNodes,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, allNodes, 1) assert.Len(t, allNodes, 1)
assert.Equal(t, allNodes[0].GetId(), node.GetId()) assert.Equal(t, allNodes[0].GetId(), node.GetId())
assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
@ -1548,12 +1555,12 @@ func TestNodeMoveCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.ErrorContains( require.ErrorContains(
t, t,
err, err,
"user not found", "user not found",
) )
assert.Equal(t, node.GetUser().GetName(), "new-user") assert.Equal(t, "new-user", node.GetUser().GetName())
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -1570,9 +1577,9 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -1589,9 +1596,9 @@ func TestNodeMoveCommand(t *testing.T) {
}, },
&node, &node,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user") assert.Equal(t, "old-user", node.GetUser().GetName())
} }
func TestPolicyCommand(t *testing.T) { func TestPolicyCommand(t *testing.T) {
@ -1599,7 +1606,7 @@ func TestPolicyCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1614,10 +1621,10 @@ func TestPolicyCommand(t *testing.T) {
"HEADSCALE_POLICY_MODE": "database", "HEADSCALE_POLICY_MODE": "database",
}), }),
) )
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
p := policy.ACLPolicy{ p := policy.ACLPolicy{
ACLs: []policy.ACL{ ACLs: []policy.ACL{
@ -1637,7 +1644,7 @@ func TestPolicyCommand(t *testing.T) {
policyFilePath := "/etc/headscale/policy.json" policyFilePath := "/etc/headscale/policy.json"
err = headscale.WriteFile(policyFilePath, pBytes) err = headscale.WriteFile(policyFilePath, pBytes)
assertNoErr(t, err) require.NoError(t, err)
// No policy is present at this time. // No policy is present at this time.
// Add a new policy from a file. // Add a new policy from a file.
@ -1651,7 +1658,7 @@ func TestPolicyCommand(t *testing.T) {
}, },
) )
assertNoErr(t, err) require.NoError(t, err)
// Get the current policy and check // Get the current policy and check
// if it is the same as the one we set. // if it is the same as the one we set.
@ -1667,11 +1674,11 @@ func TestPolicyCommand(t *testing.T) {
}, },
&output, &output,
) )
assertNoErr(t, err) require.NoError(t, err)
assert.Len(t, output.TagOwners, 1) assert.Len(t, output.TagOwners, 1)
assert.Len(t, output.ACLs, 1) assert.Len(t, output.ACLs, 1)
assert.Equal(t, output.TagOwners["tag:exists"], []string{"policy-user"}) assert.Equal(t, []string{"policy-user"}, output.TagOwners["tag:exists"])
} }
func TestPolicyBrokenConfigCommand(t *testing.T) { func TestPolicyBrokenConfigCommand(t *testing.T) {
@ -1679,7 +1686,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{ spec := map[string]int{
@ -1694,10 +1701,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
"HEADSCALE_POLICY_MODE": "database", "HEADSCALE_POLICY_MODE": "database",
}), }),
) )
assertNoErr(t, err) require.NoError(t, err)
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) require.NoError(t, err)
p := policy.ACLPolicy{ p := policy.ACLPolicy{
ACLs: []policy.ACL{ ACLs: []policy.ACL{
@ -1719,7 +1726,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath := "/etc/headscale/policy.json" policyFilePath := "/etc/headscale/policy.json"
err = headscale.WriteFile(policyFilePath, pBytes) err = headscale.WriteFile(policyFilePath, pBytes)
assertNoErr(t, err) require.NoError(t, err)
// No policy is present at this time. // No policy is present at this time.
// Add a new policy from a file. // Add a new policy from a file.
@ -1732,7 +1739,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath, policyFilePath,
}, },
) )
assert.ErrorContains(t, err, "verifying policy rules: invalid action") require.ErrorContains(t, err, "verifying policy rules: invalid action")
// The new policy was invalid, the old one should still be in place, which // The new policy was invalid, the old one should still be in place, which
// is none. // is none.
@ -1745,5 +1752,5 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.ErrorContains(t, err, "acl policy not found") require.ErrorContains(t, err, "acl policy not found")
} }

View file

@ -18,6 +18,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -244,7 +245,11 @@ func TestEphemeral(t *testing.T) {
} }
func TestEphemeralInAlternateTimezone(t *testing.T) { func TestEphemeralInAlternateTimezone(t *testing.T) {
testEphemeralWithOptions(t, hsic.WithTestName("ephemeral-tz"), hsic.WithTimezone("America/Los_Angeles")) testEphemeralWithOptions(
t,
hsic.WithTestName("ephemeral-tz"),
hsic.WithTimezone("America/Los_Angeles"),
)
} }
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
@ -1164,10 +1169,10 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
}, },
&nodeList, &nodeList,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, nodeList, 2) assert.Len(t, nodeList, 2)
assert.True(t, nodeList[0].Online) assert.True(t, nodeList[0].GetOnline())
assert.True(t, nodeList[1].Online) assert.True(t, nodeList[1].GetOnline())
// Delete the first node, which is online // Delete the first node, which is online
_, err = headscale.Execute( _, err = headscale.Execute(
@ -1177,13 +1182,13 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
"delete", "delete",
"--identifier", "--identifier",
// Delete the last added machine // Delete the last added machine
fmt.Sprintf("%d", nodeList[0].Id), fmt.Sprintf("%d", nodeList[0].GetId()),
"--output", "--output",
"json", "json",
"--force", "--force",
}, },
) )
assert.Nil(t, err) require.NoError(t, err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
@ -1200,9 +1205,8 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
}, },
&nodeListAfter, &nodeListAfter,
) )
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, nodeListAfter, 1) assert.Len(t, nodeListAfter, 1)
assert.True(t, nodeListAfter[0].Online) assert.True(t, nodeListAfter[0].GetOnline())
assert.Equal(t, nodeList[1].Id, nodeListAfter[0].Id) assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
} }

View file

@ -92,9 +92,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, routes, 3) assert.Len(t, routes, 3)
for _, route := range routes { for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised()) assert.True(t, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled()) assert.False(t, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary()) assert.False(t, route.GetIsPrimary())
} }
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
@ -139,9 +139,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, enablingRoutes, 3) assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes { for _, route := range enablingRoutes {
assert.Equal(t, true, route.GetAdvertised()) assert.True(t, route.GetAdvertised())
assert.Equal(t, true, route.GetEnabled()) assert.True(t, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary()) assert.True(t, route.GetIsPrimary())
} }
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -212,18 +212,18 @@ func TestEnablingRoutes(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
for _, route := range disablingRoutes { for _, route := range disablingRoutes {
assert.Equal(t, true, route.GetAdvertised()) assert.True(t, route.GetAdvertised())
if route.GetId() == routeToBeDisabled.GetId() { if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, false, route.GetEnabled()) assert.False(t, route.GetEnabled())
// since this is the only route of this cidr, // since this is the only route of this cidr,
// it will not failover, and remain Primary // it will not failover, and remain Primary
// until something can replace it. // until something can replace it.
assert.Equal(t, true, route.GetIsPrimary()) assert.True(t, route.GetIsPrimary())
} else { } else {
assert.Equal(t, true, route.GetEnabled()) assert.True(t, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary()) assert.True(t, route.GetIsPrimary())
} }
} }
@ -342,9 +342,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("initial routes %#v", routes) t.Logf("initial routes %#v", routes)
for _, route := range routes { for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised()) assert.True(t, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled()) assert.False(t, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary()) assert.False(t, route.GetIsPrimary())
} }
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
@ -391,14 +391,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, enablingRoutes, 2) assert.Len(t, enablingRoutes, 2)
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) assert.True(t, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled()) assert.True(t, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary") assert.True(t, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
// Node 2 is not primary // Node 2 is not primary
assert.Equal(t, true, enablingRoutes[1].GetAdvertised()) assert.True(t, enablingRoutes[1].GetAdvertised())
assert.Equal(t, true, enablingRoutes[1].GetEnabled()) assert.True(t, enablingRoutes[1].GetEnabled())
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary") assert.False(t, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1, err := subRouter1.Status() srs1, err := subRouter1.Status()
@ -446,14 +446,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, routesAfterMove, 2) assert.Len(t, routesAfterMove, 2)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterMove[0].GetAdvertised()) assert.True(t, routesAfterMove[0].GetAdvertised())
assert.Equal(t, true, routesAfterMove[0].GetEnabled()) assert.True(t, routesAfterMove[0].GetEnabled())
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary") assert.False(t, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfterMove[1].GetAdvertised()) assert.True(t, routesAfterMove[1].GetAdvertised())
assert.Equal(t, true, routesAfterMove[1].GetEnabled()) assert.True(t, routesAfterMove[1].GetEnabled())
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary") assert.True(t, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
srs2, err = subRouter2.Status() srs2, err = subRouter2.Status()
@ -501,16 +501,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, routesAfterBothDown, 2) assert.Len(t, routesAfterBothDown, 2)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised()) assert.True(t, routesAfterBothDown[0].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled()) assert.True(t, routesAfterBothDown[0].GetEnabled())
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary") assert.False(t, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
// Node 2 is primary // Node 2 is primary
// if the node goes down, but no other suitable route is // if the node goes down, but no other suitable route is
// available, keep the last known good route. // available, keep the last known good route.
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised()) assert.True(t, routesAfterBothDown[1].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled()) assert.True(t, routesAfterBothDown[1].GetEnabled())
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary") assert.True(t, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
// TODO(kradalby): Check client status // TODO(kradalby): Check client status
// Both are expected to be down // Both are expected to be down
@ -560,14 +560,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, routesAfter1Up, 2) assert.Len(t, routesAfter1Up, 2)
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised()) assert.True(t, routesAfter1Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[0].GetEnabled()) assert.True(t, routesAfter1Up[0].GetEnabled())
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary") assert.True(t, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
// Node 2 is not primary // Node 2 is not primary
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised()) assert.True(t, routesAfter1Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[1].GetEnabled()) assert.True(t, routesAfter1Up[1].GetEnabled())
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary") assert.False(t, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -614,14 +614,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, routesAfter2Up, 2) assert.Len(t, routesAfter2Up, 2)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised()) assert.True(t, routesAfter2Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[0].GetEnabled()) assert.True(t, routesAfter2Up[0].GetEnabled())
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary") assert.True(t, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised()) assert.True(t, routesAfter2Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[1].GetEnabled()) assert.True(t, routesAfter2Up[1].GetEnabled())
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary") assert.False(t, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -677,14 +677,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("routes after disabling r1 %#v", routesAfterDisabling1) t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.True(t, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) assert.False(t, routesAfterDisabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary()) assert.False(t, routesAfterDisabling1[0].GetIsPrimary())
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised()) assert.True(t, routesAfterDisabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled()) assert.True(t, routesAfterDisabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary()) assert.True(t, routesAfterDisabling1[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -735,14 +735,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.Len(t, routesAfterEnabling1, 2) assert.Len(t, routesAfterEnabling1, 2)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised()) assert.True(t, routesAfterEnabling1[0].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled()) assert.True(t, routesAfterEnabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary()) assert.False(t, routesAfterEnabling1[0].GetIsPrimary())
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised()) assert.True(t, routesAfterEnabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled()) assert.True(t, routesAfterEnabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary()) assert.True(t, routesAfterEnabling1[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -795,9 +795,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("routes after deleting r2 %#v", routesAfterDeleting2) t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised()) assert.True(t, routesAfterDeleting2[0].GetAdvertised())
assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled()) assert.True(t, routesAfterDeleting2[0].GetEnabled())
assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary()) assert.True(t, routesAfterDeleting2[0].GetIsPrimary())
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -893,9 +893,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
// All routes should be auto approved and enabled // All routes should be auto approved and enabled
assert.Equal(t, true, routes[0].GetAdvertised()) assert.True(t, routes[0].GetAdvertised())
assert.Equal(t, true, routes[0].GetEnabled()) assert.True(t, routes[0].GetEnabled())
assert.Equal(t, true, routes[0].GetIsPrimary()) assert.True(t, routes[0].GetIsPrimary())
// Stop advertising route // Stop advertising route
command = []string{ command = []string{
@ -924,9 +924,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
assert.Len(t, notAdvertisedRoutes, 1) assert.Len(t, notAdvertisedRoutes, 1)
// Route is no longer advertised // Route is no longer advertised
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised()) assert.False(t, notAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled()) assert.False(t, notAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary()) assert.True(t, notAdvertisedRoutes[0].GetIsPrimary())
// Advertise route again // Advertise route again
command = []string{ command = []string{
@ -955,9 +955,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
assert.Len(t, reAdvertisedRoutes, 1) assert.Len(t, reAdvertisedRoutes, 1)
// All routes should be auto approved and enabled // All routes should be auto approved and enabled
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised()) assert.True(t, reAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled()) assert.True(t, reAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary()) assert.True(t, reAdvertisedRoutes[0].GetIsPrimary())
} }
func TestAutoApprovedSubRoute2068(t *testing.T) { func TestAutoApprovedSubRoute2068(t *testing.T) {
@ -1163,9 +1163,9 @@ func TestSubnetRouteACL(t *testing.T) {
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
for _, route := range routes { for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised()) assert.True(t, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled()) assert.False(t, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary()) assert.False(t, route.GetIsPrimary())
} }
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
@ -1212,9 +1212,9 @@ func TestSubnetRouteACL(t *testing.T) {
assert.Len(t, enablingRoutes, 1) assert.Len(t, enablingRoutes, 1)
// Node 1 has active route // Node 1 has active route
assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) assert.True(t, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled()) assert.True(t, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) assert.True(t, enablingRoutes[0].GetIsPrimary())
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status() srs1, _ := subRouter1.Status()

View file

@ -20,6 +20,7 @@ import (
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/envknob" "tailscale.com/envknob"
) )
@ -203,11 +204,11 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
if t != nil { if t != nil {
stdout, err := os.ReadFile(stdoutPath) stdout, err := os.ReadFile(stdoutPath)
assert.NoError(t, err) require.NoError(t, err)
assert.NotContains(t, string(stdout), "panic") assert.NotContains(t, string(stdout), "panic")
stderr, err := os.ReadFile(stderrPath) stderr, err := os.ReadFile(stderrPath)
assert.NoError(t, err) require.NoError(t, err)
assert.NotContains(t, string(stderr), "panic") assert.NotContains(t, string(stderr), "panic")
} }