mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
move to use tailscfg types over strings/custom types (#1612)
* rename database only fields Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use correct endpoint type over string list Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * remove HostInfo wrapper Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * wrap errors in database hooks Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
ed4e19996b
commit
b918aa03fc
13 changed files with 147 additions and 154 deletions
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
|
@ -593,7 +594,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:exit"},
|
||||
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
|
||||
},
|
||||
|
|
|
@ -274,7 +274,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
|
|||
}
|
||||
|
||||
advertisedRoutes := map[netip.Prefix]bool{}
|
||||
for _, prefix := range node.HostInfo.RoutableIPs {
|
||||
for _, prefix := range node.Hostinfo.RoutableIPs {
|
||||
advertisedRoutes[prefix] = false
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
|
||||
|
@ -81,7 +81,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
|
||||
|
@ -152,7 +152,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
Hostinfo: &hostInfo1,
|
||||
}
|
||||
db.db.Save(&node1)
|
||||
|
||||
|
@ -174,7 +174,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
Hostinfo: &hostInfo2,
|
||||
}
|
||||
db.db.Save(&node2)
|
||||
|
||||
|
@ -232,7 +232,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
Hostinfo: &hostInfo1,
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&node1)
|
||||
|
@ -266,7 +266,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
Hostinfo: &hostInfo2,
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&node2)
|
||||
|
@ -313,9 +313,9 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 1)
|
||||
|
||||
node2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
|
||||
node2.Hostinfo = &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||
})
|
||||
}
|
||||
err = db.db.Save(&node2).Error
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -368,7 +368,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
Hostinfo: &hostInfo1,
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&node1)
|
||||
|
|
|
@ -550,7 +550,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||
Expiry: &time.Time{},
|
||||
LastSeen: &time.Time{},
|
||||
|
||||
HostInfo: types.HostInfo(hostinfo),
|
||||
Hostinfo: &hostinfo,
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
|
|
|
@ -195,7 +195,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.HostInfo.OS},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
}
|
||||
|
||||
if len(node.IPAddresses) > 0 {
|
||||
|
|
|
@ -186,8 +186,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
HostInfo: types.HostInfo{},
|
||||
Endpoints: []string{},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
|
||||
|
@ -267,8 +266,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
HostInfo: types.HostInfo{},
|
||||
Endpoints: []string{},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
@ -324,8 +322,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
HostInfo: types.HostInfo{},
|
||||
Endpoints: []string{},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
|
|
@ -72,8 +72,8 @@ func tailNode(
|
|||
}
|
||||
|
||||
var derp string
|
||||
if node.HostInfo.NetInfo != nil {
|
||||
derp = fmt.Sprintf("127.3.3.40:%d", node.HostInfo.NetInfo.PreferredDERP)
|
||||
if node.Hostinfo.NetInfo != nil {
|
||||
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
|
||||
} else {
|
||||
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
|
||||
}
|
||||
|
@ -90,18 +90,11 @@ func tailNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hostInfo := node.GetHostInfo()
|
||||
|
||||
online := node.IsOnline()
|
||||
|
||||
tags, _ := pol.TagsOfNode(node)
|
||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
|
||||
endpoints, err := node.EndpointsToAddrPort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tNode := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(node.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
|
@ -118,9 +111,9 @@ func tailNode(
|
|||
DiscoKey: node.DiscoKey,
|
||||
Addresses: addrs,
|
||||
AllowedIPs: allowedIPs,
|
||||
Endpoints: endpoints,
|
||||
Endpoints: node.Endpoints,
|
||||
DERP: derp,
|
||||
Hostinfo: hostInfo.View(),
|
||||
Hostinfo: node.Hostinfo.View(),
|
||||
Created: node.CreatedAt,
|
||||
|
||||
Tags: tags,
|
||||
|
|
|
@ -53,8 +53,10 @@ func TestTailNode(t *testing.T) {
|
|||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty-node",
|
||||
node: &types.Node{},
|
||||
name: "empty-node",
|
||||
node: &types.Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
|
@ -102,8 +104,7 @@ func TestTailNode(t *testing.T) {
|
|||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
HostInfo: types.HostInfo{},
|
||||
Endpoints: []string{},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
|
||||
|
|
|
@ -596,10 +596,13 @@ func excludeCorrectlyTaggedNodes(
|
|||
}
|
||||
// for each node if tag is in tags list, don't append it.
|
||||
for _, node := range nodes {
|
||||
hi := node.GetHostInfo()
|
||||
|
||||
found := false
|
||||
for _, t := range hi.RequestTags {
|
||||
|
||||
if node.Hostinfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, t := range node.Hostinfo.RequestTags {
|
||||
if util.StringOrPrefixListContains(tags, t) {
|
||||
found = true
|
||||
|
||||
|
@ -787,8 +790,11 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
for _, user := range owners {
|
||||
nodes := filterNodesByUser(nodes, user)
|
||||
for _, node := range nodes {
|
||||
hi := node.GetHostInfo()
|
||||
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
|
||||
if node.Hostinfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
@ -882,7 +888,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
|
||||
validTagMap := make(map[string]bool)
|
||||
invalidTagMap := make(map[string]bool)
|
||||
for _, tag := range node.HostInfo.RequestTags {
|
||||
for _, tag := range node.Hostinfo.RequestTags {
|
||||
owners, err := expandOwnersFromTag(pol, tag)
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
invalidTagMap[tag] = true
|
||||
|
|
|
@ -418,6 +418,7 @@ acls:
|
|||
User: types.User{
|
||||
Name: "testuser",
|
||||
},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -1264,7 +1265,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1275,7 +1276,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1405,7 +1406,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1443,7 +1444,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1454,7 +1455,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1464,13 +1465,15 @@ func Test_expandAlias(t *testing.T) {
|
|||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.3"),
|
||||
},
|
||||
User: types.User{Name: "marc"},
|
||||
User: types.User{Name: "marc"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1520,7 +1523,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1531,7 +1534,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1541,7 +1544,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
user: "joe",
|
||||
|
@ -1550,6 +1554,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1570,7 +1575,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1581,7 +1586,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1591,7 +1596,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
user: "joe",
|
||||
|
@ -1600,6 +1606,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1615,7 +1622,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
RequestTags: []string{"tag:accountant-webserver"},
|
||||
|
@ -1627,12 +1634,14 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
},
|
||||
User: types.User{Name: "joe"},
|
||||
ForcedTags: []string{"tag:accountant-webserver"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
user: "joe",
|
||||
|
@ -1641,6 +1650,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1656,7 +1666,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "hr-web1",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1667,7 +1677,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "hr-web2",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1677,7 +1687,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
user: "joe",
|
||||
|
@ -1688,7 +1699,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "hr-web1",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1699,7 +1710,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
netip.MustParseAddr("100.64.0.2"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "hr-web2",
|
||||
RequestTags: []string{"tag:hr-webserver"},
|
||||
|
@ -1709,7 +1720,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.4"),
|
||||
},
|
||||
User: types.User{Name: "joe"},
|
||||
User: types.User{Name: "joe"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1952,7 +1964,7 @@ func Test_getTags(t *testing.T) {
|
|||
User: types.User{
|
||||
Name: "joe",
|
||||
},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:valid"},
|
||||
},
|
||||
},
|
||||
|
@ -1972,7 +1984,7 @@ func Test_getTags(t *testing.T) {
|
|||
User: types.User{
|
||||
Name: "joe",
|
||||
},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:valid", "tag:invalid"},
|
||||
},
|
||||
},
|
||||
|
@ -1992,7 +2004,7 @@ func Test_getTags(t *testing.T) {
|
|||
User: types.User{
|
||||
Name: "joe",
|
||||
},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{
|
||||
"tag:invalid",
|
||||
"tag:valid",
|
||||
|
@ -2016,7 +2028,7 @@ func Test_getTags(t *testing.T) {
|
|||
User: types.User{
|
||||
Name: "joe",
|
||||
},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:invalid", "very-invalid"},
|
||||
},
|
||||
},
|
||||
|
@ -2032,7 +2044,7 @@ func Test_getTags(t *testing.T) {
|
|||
User: types.User{
|
||||
Name: "joe",
|
||||
},
|
||||
HostInfo: types.HostInfo{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:invalid", "very-invalid"},
|
||||
},
|
||||
},
|
||||
|
@ -3010,7 +3022,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
|||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
||||
pol := &ACLPolicy{
|
||||
|
@ -3062,7 +3074,7 @@ func TestInvalidTagValidUser(t *testing.T) {
|
|||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
||||
pol := &ACLPolicy{
|
||||
|
@ -3113,7 +3125,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
|||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
||||
pol := &ACLPolicy{
|
||||
|
@ -3174,7 +3186,7 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
|
@ -3191,7 +3203,7 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
Hostinfo: &hostInfo2,
|
||||
}
|
||||
|
||||
pol := &ACLPolicy{
|
||||
|
|
|
@ -83,15 +83,14 @@ func (h *Headscale) handlePoll(
|
|||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("node", node.Hostname).
|
||||
Strs("endpoints", node.Endpoints).
|
||||
Msg("Received endpoint update")
|
||||
|
||||
now := time.Now().UTC()
|
||||
node.LastSeen = &now
|
||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
node.Hostinfo = mapRequest.Hostinfo
|
||||
node.DiscoKey = mapRequest.DiscoKey
|
||||
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
|
||||
node.Endpoints = mapRequest.Endpoints
|
||||
|
||||
if err := h.db.NodeSave(node); err != nil {
|
||||
logErr(err, "Failed to persist/update node in the database")
|
||||
|
@ -142,9 +141,9 @@ func (h *Headscale) handlePoll(
|
|||
now := time.Now().UTC()
|
||||
node.LastSeen = &now
|
||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
node.Hostinfo = mapRequest.Hostinfo
|
||||
node.DiscoKey = mapRequest.DiscoKey
|
||||
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
|
||||
node.Endpoints = mapRequest.Endpoints
|
||||
|
||||
// When a node connects to control, list the peers it has at
|
||||
// that given point, further updates are kept in memory in
|
||||
|
|
|
@ -12,33 +12,6 @@ import (
|
|||
|
||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
|
||||
// This is a "wrapper" type around tailscales
|
||||
// Hostinfo to allow us to add database "serialization"
|
||||
// methods. This allows us to use a typed values throughout
|
||||
// the code and not have to marshal/unmarshal and error
|
||||
// check all over the code.
|
||||
type HostInfo tailcfg.Hostinfo
|
||||
|
||||
func (hi *HostInfo) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, hi)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), hi)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (hi HostInfo) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(hi)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type IPPrefix netip.Prefix
|
||||
|
||||
func (i *IPPrefix) Scan(destination interface{}) error {
|
||||
|
|
|
@ -2,6 +2,7 @@ package types
|
|||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
@ -27,27 +28,40 @@ var (
|
|||
type Node struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
|
||||
// MachineKeyValue is the string representation of MachineKey
|
||||
// MachineKeyDatabaseField is the string representation of MachineKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use MachineKey instead.
|
||||
MachineKeyValue string `gorm:"column:machine_key;unique_index"`
|
||||
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
|
||||
MachineKey key.MachinePublic `gorm:"-"`
|
||||
|
||||
// NodeKeyValue is the string representation of NodeKey
|
||||
// NodeKeyDatabaseField is the string representation of NodeKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use NodeKey instead.
|
||||
NodeKeyValue string `gorm:"column:node_key"`
|
||||
NodeKeyDatabaseField string `gorm:"column:node_key"`
|
||||
NodeKey key.NodePublic `gorm:"-"`
|
||||
|
||||
// DiscoKeyValue is the string representation of DiscoKey
|
||||
// DiscoKeyDatabaseField is the string representation of DiscoKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use DiscoKey instead.
|
||||
DiscoKeyValue string `gorm:"column:disco_key"`
|
||||
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
|
||||
DiscoKey key.DiscoPublic `gorm:"-"`
|
||||
|
||||
MachineKey key.MachinePublic `gorm:"-"`
|
||||
NodeKey key.NodePublic `gorm:"-"`
|
||||
DiscoKey key.DiscoPublic `gorm:"-"`
|
||||
// EndpointsDatabaseField is the string list representation of Endpoints
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use Endpoints instead.
|
||||
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
|
||||
Endpoints []netip.AddrPort `gorm:"-"`
|
||||
|
||||
// EndpointsDatabaseField is the string list representation of Endpoints
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use Endpoints instead.
|
||||
HostinfoDatabaseField string `gorm:"column:hostinfo"`
|
||||
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
|
||||
|
||||
IPAddresses NodeAddresses
|
||||
|
||||
|
@ -76,9 +90,6 @@ type Node struct {
|
|||
LastSeen *time.Time
|
||||
Expiry *time.Time
|
||||
|
||||
HostInfo HostInfo
|
||||
Endpoints StringList
|
||||
|
||||
Routes []Route
|
||||
|
||||
CreatedAt time.Time
|
||||
|
@ -195,31 +206,6 @@ func (node Node) IsExpired() bool {
|
|||
return time.Now().UTC().After(*node.Expiry)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Try to replace the types in the DB to be correct.
|
||||
func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) {
|
||||
var ret []netip.AddrPort
|
||||
for _, ep := range node.Endpoints {
|
||||
addrPort, err := netip.ParseAddrPort(ep)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = append(ret, addrPort)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): Try to replace the types in the DB to be correct.
|
||||
func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) {
|
||||
var strs StringList
|
||||
for _, addrPort := range in {
|
||||
strs = append(strs, addrPort.String())
|
||||
}
|
||||
|
||||
node.Endpoints = strs
|
||||
}
|
||||
|
||||
// IsOnline returns if the node is connected to Headscale.
|
||||
// This is really a naive implementation, as we don't really see
|
||||
// if there is a working connection between the client and the server.
|
||||
|
@ -277,9 +263,22 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
|||
// correctly in the database.
|
||||
// This currently means storing the keys as strings.
|
||||
func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
|
||||
n.MachineKeyValue = n.MachineKey.String()
|
||||
n.NodeKeyValue = n.NodeKey.String()
|
||||
n.DiscoKeyValue = n.DiscoKey.String()
|
||||
n.MachineKeyDatabaseField = n.MachineKey.String()
|
||||
n.NodeKeyDatabaseField = n.NodeKey.String()
|
||||
n.DiscoKeyDatabaseField = n.DiscoKey.String()
|
||||
|
||||
var endpoints StringList
|
||||
for _, addrPort := range n.Endpoints {
|
||||
endpoints = append(endpoints, addrPort.String())
|
||||
}
|
||||
|
||||
n.EndpointsDatabaseField = endpoints
|
||||
|
||||
hi, err := json.Marshal(n.Hostinfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
|
||||
}
|
||||
n.HostinfoDatabaseField = string(hi)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -291,23 +290,40 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
|
|||
// the proper types.
|
||||
func (n *Node) AfterFind(tx *gorm.DB) (err error) {
|
||||
var machineKey key.MachinePublic
|
||||
if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil {
|
||||
return err
|
||||
if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
|
||||
}
|
||||
n.MachineKey = machineKey
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil {
|
||||
return err
|
||||
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
|
||||
}
|
||||
n.NodeKey = nodeKey
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil {
|
||||
return err
|
||||
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
|
||||
}
|
||||
n.DiscoKey = discoKey
|
||||
|
||||
var endpoints []netip.AddrPort
|
||||
for _, ep := range n.EndpointsDatabaseField {
|
||||
addrPort, err := netip.ParseAddrPort(ep)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint from db: %w", err)
|
||||
}
|
||||
|
||||
endpoints = append(endpoints, addrPort)
|
||||
}
|
||||
n.Endpoints = endpoints
|
||||
|
||||
var hi tailcfg.Hostinfo
|
||||
if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
|
||||
}
|
||||
n.Hostinfo = &hi
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -346,11 +362,6 @@ func (node *Node) Proto() *v1.Node {
|
|||
return nodeProto
|
||||
}
|
||||
|
||||
// GetHostInfo returns a Hostinfo struct for the node.
|
||||
func (node *Node) GetHostInfo() tailcfg.Hostinfo {
|
||||
return tailcfg.Hostinfo(node.HostInfo)
|
||||
}
|
||||
|
||||
func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) {
|
||||
var hostname string
|
||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||
|
|
Loading…
Reference in a new issue