This commit is contained in:
LiuHanCheng 2022-11-05 16:07:22 +08:00 committed by GitHub
parent bf87b33292
commit 07f92e647c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 47 additions and 35 deletions

View file

@ -26,6 +26,7 @@ linters:
- ireturn - ireturn
- execinquery - execinquery
- exhaustruct - exhaustruct
- nolintlint
# We should strive to enable these: # We should strive to enable these:
- wrapcheck - wrapcheck

View file

@ -3,6 +3,7 @@ package cli
import ( import (
"fmt" "fmt"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -10,8 +11,7 @@ import (
) )
const ( const (
keyLength = 64 errPreAuthKeyMalformed = Error("key is malformed. expected 64 hex characters with `nodekey` prefix")
errPreAuthKeyTooShort = Error("key too short, must be 64 hexadecimal characters")
) )
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
@ -87,8 +87,8 @@ var createNodeCmd = &cobra.Command{
return return
} }
if len(machineKey) != keyLength { if !headscale.NodePublicKeyRegex.Match([]byte(machineKey)) {
err = errPreAuthKeyTooShort err = errPreAuthKeyMalformed
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error: %s", err), fmt.Sprintf("Error: %s", err),

View file

@ -1,4 +1,4 @@
//nolint // nolint
package headscale package headscale
import ( import (
@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
@ -496,9 +497,14 @@ func (api headscaleV1APIServer) DebugCreateMachine(
HostInfo: HostInfo(hostinfo), HostInfo: HostInfo(hostinfo),
} }
nodeKey := key.NodePublic{}
err = nodeKey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
log.Panic().Msg("can not add machine for debug. invalid node key")
}
api.h.registrationCache.Set( api.h.registrationCache.Set(
request.GetKey(), NodePublicKeyStripPrefix(nodeKey),
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )

View file

@ -1,4 +1,4 @@
//nolint // nolint
package headscale package headscale
import ( import (
@ -558,8 +558,8 @@ func (s *IntegrationCLITestSuite) TestNodeTagCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
machineKeys := []string{ machineKeys := []string{
"9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
"6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
} }
machines := make([]*v1.Machine, len(machineKeys)) machines := make([]*v1.Machine, len(machineKeys))
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -691,11 +691,11 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
// Randomly generated machine keys // Randomly generated machine keys
machineKeys := []string{ machineKeys := []string{
"9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
"6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
"f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
"8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1",
"cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
} }
machines := make([]*v1.Machine, len(machineKeys)) machines := make([]*v1.Machine, len(machineKeys))
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -779,8 +779,8 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Equal(s.T(), "machine-5", listAll[4].Name) assert.Equal(s.T(), "machine-5", listAll[4].Name)
otherNamespaceMachineKeys := []string{ otherNamespaceMachineKeys := []string{
"b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", "nodekey:b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e",
"dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", "nodekey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584",
} }
otherNamespaceMachines := make([]*v1.Machine, len(otherNamespaceMachineKeys)) otherNamespaceMachines := make([]*v1.Machine, len(otherNamespaceMachineKeys))
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -950,11 +950,11 @@ func (s *IntegrationCLITestSuite) TestNodeExpireCommand() {
// Randomly generated machine keys // Randomly generated machine keys
machineKeys := []string{ machineKeys := []string{
"9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
"6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
"f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
"8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1",
"cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
} }
machines := make([]*v1.Machine, len(machineKeys)) machines := make([]*v1.Machine, len(machineKeys))
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -1077,11 +1077,11 @@ func (s *IntegrationCLITestSuite) TestNodeRenameCommand() {
// Randomly generated machine keys // Randomly generated machine keys
machineKeys := []string{ machineKeys := []string{
"cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
"8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1",
"f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
"6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
"9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
} }
machines := make([]*v1.Machine, len(machineKeys)) machines := make([]*v1.Machine, len(machineKeys))
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
@ -1248,7 +1248,7 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
// Randomly generated machine keys // Randomly generated machine keys
machineKey := "9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe" machineKey := "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe"
_, _, err = ExecuteCommand( _, _, err = ExecuteCommand(
&s.headscale, &s.headscale,
@ -1588,7 +1588,7 @@ func (s *IntegrationCLITestSuite) TestNodeMoveCommand() {
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
// Randomly generated machine key // Randomly generated machine key
machineKey := "688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" machineKey := "nodekey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa"
_, _, err = ExecuteCommand( _, _, err = ExecuteCommand(
&s.headscale, &s.headscale,

View file

@ -1,4 +1,4 @@
//nolint // nolint
package headscale package headscale
import ( import (

View file

@ -839,7 +839,13 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { nodeKey := key.NodePublic{}
err := nodeKey.UnmarshalText([]byte(nodeKeyStr))
if err != nil {
return nil, err
}
if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {

View file

@ -604,10 +604,9 @@ func (h *Headscale) registerMachineForOIDCCallback(
namespace *Namespace, namespace *Namespace,
nodeKey *key.NodePublic, nodeKey *key.NodePublic,
) error { ) error {
nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey)
if _, err := h.RegisterMachineFromAuthCallback( if _, err := h.RegisterMachineFromAuthCallback(
nodeKeyStr, nodeKey.String(),
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
); err != nil { ); err != nil {

View file

@ -490,12 +490,12 @@ func (h *Headscale) handleNewMachineCommon(
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey), registerRequest.NodeKey,
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register/%s", resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey)) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey)
@ -726,7 +726,7 @@ func (h *Headscale) handleMachineExpiredCommon(
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register/%s", resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey)) registerRequest.NodeKey)
} }
respBody, err := h.marshalResponse(resp, machineKey) respBody, err := h.marshalResponse(resp, machineKey)