diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 80daf20a..65324f77 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -50,6 +50,7 @@ jobs: - TestEphemeral2006DeletedTooQuickly - TestPingAllByHostname - TestTaildrop + - TestUpdateHostnameFromClient - TestExpireNode - TestNodeOnlineStatus - TestPingAllByIPManyUpDown diff --git a/CHANGELOG.md b/CHANGELOG.md index 22f05780..465adc87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - Allow nodes to use SSH agent forwarding [#2145](https://github.com/juanfont/headscale/pull/2145) - Fixed processing of fields in post request in MoveNode rpc [#2179](https://github.com/juanfont/headscale/pull/2179) - Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198) +- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199) ## 0.23.0 (2024-09-18) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 755265f3..a8ae01f4 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -471,7 +471,7 @@ func (m *mapSession) handleEndpointUpdate() { // Check if the Hostinfo of the node has changed. // If it has changed, check if there has been a change to - // the routable IPs of the host and update update them in + // the routable IPs of the host and update them in // the database. Then send a Changed update // (containing the whole node object) to peers to inform about // the route change. @@ -510,6 +510,12 @@ func (m *mapSession) handleEndpointUpdate() { m.node.ID) } + // Check if there has been a change to Hostname and update them + // in the database. Then send a Changed update + // (containing the whole node object) to peers to inform about + // the hostname change. + m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo) + if err := m.h.db.DB.Save(m.node).Error; err != nil { m.errf(err, "Failed to persist/update node in the database") http.Error(m.w, "", http.StatusInternalServerError) @@ -526,7 +532,8 @@ func (m *mapSession) handleEndpointUpdate() { ChangeNodes: []types.NodeID{m.node.ID}, Message: "called from handlePoll -> update", }, - m.node.ID) + m.node.ID, + ) m.w.WriteHeader(http.StatusOK) mapResponseEndpointUpdates.WithLabelValues("ok").Inc() diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index c702f23a..9d632bd8 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -97,6 +97,11 @@ type ( Nodes []*Node ) +// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node. +func (node *Node) GivenNameHasBeenChanged() bool { + return node.GivenName == util.ConvertWithFQDNRules(node.Hostname) +} + // IsExpired returns whether the node registration has expired. func (node Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that @@ -347,6 +352,21 @@ func (node *Node) RegisterMethodToV1Enum() v1.RegisterMethod { } } +// ApplyHostnameFromHostInfo takes a Hostinfo struct and updates the node. +func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { + if hostInfo == nil { + return + } + + if node.Hostname != hostInfo.Hostname { + if node.GivenNameHasBeenChanged() { + node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) + } + + node.Hostname = hostInfo.Hostname + } +} + // ApplyPeerChange takes a PeerChange struct and updates the node. func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) { if change.Key != nil { diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 1d0e7939..d439d483 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -337,6 +337,66 @@ func TestPeerChangeFromMapRequest(t *testing.T) { } } +func TestApplyHostnameFromHostInfo(t *testing.T) { + tests := []struct { + name string + nodeBefore Node + change *tailcfg.Hostinfo + want Node + }{ + { + name: "hostinfo-not-exists", + nodeBefore: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + change: nil, + want: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + }, + { + name: "hostinfo-exists-no-automatic-givenName", + nodeBefore: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + change: &tailcfg.Hostinfo{ + Hostname: "NewHostName.Local", + }, + want: Node{ + GivenName: "manual-test.local", + Hostname: "NewHostName.Local", + }, + }, + { + name: "hostinfo-exists-automatic-givenName", + nodeBefore: Node{ + GivenName: "automaticname.test", + Hostname: "AutomaticName.Test", + }, + change: &tailcfg.Hostinfo{ + Hostname: "NewHostName.Local", + }, + want: Node{ + GivenName: "newhostname.local", + Hostname: "NewHostName.Local", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.nodeBefore.ApplyHostnameFromHostInfo(tt.change) + + if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} + func TestApplyPeerChange(t *testing.T) { tests := []struct { name string diff --git a/integration/general_test.go b/integration/general_test.go index 085691fb..93b06761 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -5,12 +5,14 @@ import ( "encoding/json" "fmt" "net/netip" + "strconv" "strings" "testing" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/rs/zerolog/log" @@ -654,6 +656,134 @@ func TestTaildrop(t *testing.T) { } } +func TestUpdateHostnameFromClient(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + user := "update-hostname-from-client" + + hostnames := map[string]string{ + "1": "user1-host", + "2": "User2-Host", + "3": "user3-host", + } + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + user: 3, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("updatehostname")) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + // update hostnames using the up command + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + command := []string{ + "tailscale", + "set", + "--hostname=" + hostnames[string(status.Self.ID)], + } + _, _, err = client.Execute(command) + assertNoErrf(t, "failed to set hostname: %s", err) + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + var nodes []*v1.Node + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) + + assertNoErr(t, err) + assert.Len(t, nodes, 3) + + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + assert.Equal(t, hostname, node.GetName()) + assert.Equal(t, util.ConvertWithFQDNRules(hostname), node.GetGivenName()) + } + + // Rename givenName in nodes + for _, node := range nodes { + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + _, err = headscale.Execute( + []string{ + "headscale", + "node", + "rename", + givenName, + "--identifier", + strconv.FormatUint(node.GetId(), 10), + }) + assertNoErr(t, err) + } + + time.Sleep(5 * time.Second) + + // Verify that the clients can see the new hostname, but no givenName + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + command := []string{ + "tailscale", + "set", + "--hostname=" + hostnames[string(status.Self.ID)] + "NEW", + } + _, _, err = client.Execute(command) + assertNoErrf(t, "failed to set hostname: %s", err) + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) + + assertNoErr(t, err) + assert.Len(t, nodes, 3) + + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + assert.Equal(t, hostname+"NEW", node.GetName()) + assert.Equal(t, givenName, node.GetGivenName()) + } +} + func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel()