diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 055d2e3e..687ec93c 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -529,7 +529,7 @@ func nodesToPtables( var machineKey key.MachinePublic err := machineKey.UnmarshalText( - []byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), + []byte(node.MachineKey), ) if err != nil { machineKey = key.MachinePublic{} @@ -537,7 +537,7 @@ func nodesToPtables( var nodeKey key.NodePublic err = nodeKey.UnmarshalText( - []byte(util.NodePublicKeyEnsurePrefix(node.NodeKey)), + []byte(node.NodeKey), ) if err != nil { return nil, err diff --git a/hscontrol/app.go b/hscontrol/app.go index 59284cb1..bb67ffc4 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -911,10 +911,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { } trimmedPrivateKey := strings.TrimSpace(string(privateKey)) - privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey) var machineKey key.MachinePrivate - if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil { + if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { log.Info(). Str("path", path). Msg("This might be due to a legacy (headscale pre-0.12) private key. " + diff --git a/hscontrol/auth.go b/hscontrol/auth.go index b7563659..22dc8699 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -45,7 +45,7 @@ func (h *Headscale) handleRegister( // is that the client will hammer headscale with requests until it gets a // successful RegisterResponse. if registerRequest.Followup != "" { - if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { + if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok { log.Debug(). Caller(). Str("node", registerRequest.Hostinfo.Hostname). @@ -97,10 +97,10 @@ func (h *Headscale) handleRegister( // We create the node and then keep it around until a callback // happens newNode := types.Node{ - MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + MachineKey: machineKey.String(), Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, - NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey), + NodeKey: registerRequest.NodeKey.String(), LastSeen: &now, Expiry: &time.Time{}, } @@ -136,7 +136,7 @@ func (h *Headscale) handleRegister( // So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it. var storedMachineKey key.MachinePublic err = storedMachineKey.UnmarshalText( - []byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), + []byte(node.MachineKey), ) if err != nil || storedMachineKey.IsZero() { if err := h.db.NodeSetMachineKey(node, machineKey); err != nil { @@ -156,7 +156,7 @@ func (h *Headscale) handleRegister( // - Trying to log out (sending a expiry in the past) // - A valid, registered node, looking for /map // - Expired node wanting to reauthenticate - if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) { + if node.NodeKey == registerRequest.NodeKey.String() { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !registerRequest.Expiry.IsZero() && @@ -176,7 +176,7 @@ func (h *Headscale) handleRegister( } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && + if node.NodeKey == registerRequest.OldNodeKey.String() && !node.IsExpired() { h.handleNodeKeyRefresh( writer, @@ -207,9 +207,9 @@ func (h *Headscale) handleRegister( // we need to make sure the NodeKey matches the one in the request // TODO(juan): What happens when using fast user switching between two // headscale-managed tailnets? - node.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) + node.NodeKey = registerRequest.NodeKey.String() h.registrationCache.Set( - util.NodePublicKeyStripPrefix(registerRequest.NodeKey), + registerRequest.NodeKey.String(), *node, registerCacheExpiration, ) @@ -294,7 +294,7 @@ func (h *Headscale) handleAuthKey( Str("node", registerRequest.Hostinfo.Hostname). Msg("Authentication key was valid, proceeding to acquire IP addresses") - nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey) + nodeKey := registerRequest.NodeKey.String() // retrieve node information if it exist // The error is not important, because if it does not @@ -342,7 +342,7 @@ func (h *Headscale) handleAuthKey( } else { now := time.Now().UTC() - givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) + givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). Caller(). @@ -359,7 +359,7 @@ func (h *Headscale) handleAuthKey( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + MachineKey: machineKey.String(), RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, diff --git a/hscontrol/auth_legacy.go b/hscontrol/auth_legacy.go index f7e0382f..c3b2de34 100644 --- a/hscontrol/auth_legacy.go +++ b/hscontrol/auth_legacy.go @@ -33,7 +33,7 @@ func (h *Headscale) RegistrationHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr)) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 781fd896..07059eab 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -35,9 +35,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -83,9 +80,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) { node := types.Node{ ID: uint64(index), - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -173,9 +167,6 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 51bb4023..0deeb41d 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "strings" "sync" "time" @@ -252,6 +253,27 @@ func NewHeadscaleDatabase( return nil, err } + // Ensure all keys have correct prefixes + // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 + nodes := types.Nodes{} + if err := dbConn.Find(&nodes).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for _, node := range nodes { + if !strings.HasPrefix(node.DiscoKey, "discokey:") { + node.DiscoKey = "discokey:" + node.DiscoKey + } + + if !strings.HasPrefix(node.NodeKey, "nodekey:") { + node.NodeKey = "nodekey:" + node.NodeKey + } + + if !strings.HasPrefix(node.MachineKey, "mkey:") { + node.MachineKey = "mkey:" + node.MachineKey + } + } + // TODO(kradalby): is this needed? err = db.setValue("db_version", dbVersion) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index dc4b75dd..05d1cd3d 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -182,7 +182,7 @@ func (hsdb *HSDatabase) GetNodeByMachineKey( Preload("AuthKey.User"). Preload("User"). Preload("Routes"). - First(&mach, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { + First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil { return nil, result.Error } @@ -203,7 +203,7 @@ func (hsdb *HSDatabase) GetNodeByNodeKey( Preload("User"). Preload("Routes"). First(&node, "node_key = ?", - util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { + nodeKey.String()); result.Error != nil { return nil, result.Error } @@ -224,9 +224,9 @@ func (hsdb *HSDatabase) GetNodeByAnyKey( Preload("User"). Preload("Routes"). First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", - util.MachinePublicKeyStripPrefix(machineKey), - util.NodePublicKeyStripPrefix(nodeKey), - util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { + machineKey.String(), + nodeKey.String(), + oldNodeKey.String()); result.Error != nil { return nil, result.Error } @@ -397,7 +397,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). Msg("Registering node from API/CLI or auth callback") - if nodeInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { + if nodeInterface, ok := cache.Get(nodeKey.String()); ok { if registrationNode, ok := nodeInterface.(types.Node); ok { user, err := hsdb.getUser(userName) if err != nil { @@ -507,7 +507,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) defer hsdb.mu.Unlock() if err := hsdb.db.Model(node).Updates(types.Node{ - NodeKey: util.NodePublicKeyStripPrefix(nodeKey), + NodeKey: nodeKey.String(), }).Error; err != nil { return err } @@ -524,7 +524,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey( defer hsdb.mu.Unlock() if err := hsdb.db.Model(node).Updates(types.Node{ - MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + MachineKey: machineKey.String(), }).Error; err != nil { return err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 54b1cd07..07b6193e 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -82,8 +82,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) { node := types.Node{ ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: machineKey.Public().String(), + NodeKey: nodeKey.Public().String(), DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, @@ -113,8 +113,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { node := types.Node{ ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: machineKey.Public().String(), + NodeKey: nodeKey.Public().String(), DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, @@ -575,7 +575,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { node := types.Node{ ID: 0, MachineKey: "foo", - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + NodeKey: nodeKey.Public().String(), DiscoKey: "faa", Hostname: "test", UserID: user.ID, diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9bf8c892..52b837c7 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -77,9 +77,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -101,9 +98,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { node := types.Node{ ID: 1, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -138,9 +132,6 @@ func (*Suite) TestEphemeralKey(c *check.C) { now := time.Now().Add(-time.Second * 30) node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index ba5882b5..02959e63 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -29,9 +29,6 @@ func (s *Suite) TestGetRoutes(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_get_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -80,9 +77,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -154,9 +148,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } node1 := types.Node{ ID: 1, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -179,9 +170,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } node2 := types.Node{ ID: 2, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -240,9 +228,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) { now := time.Now() node1 := types.Node{ ID: 1, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -277,9 +262,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } node2 := types.Node{ ID: 2, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -382,9 +364,6 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { now := time.Now() node1 := types.Node{ ID: 1, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 0c43b979..1ca3b49f 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -48,9 +48,6 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, @@ -103,9 +100,6 @@ func (s *Suite) TestSetMachineUser(c *check.C) { node := types.Node{ ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", Hostname: "testnode", UserID: oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index e04e3b19..926e3e87 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -545,7 +545,7 @@ func (api headscaleV1APIServer) DebugCreateNode( } api.h.registrationCache.Set( - util.NodePublicKeyStripPrefix(nodeKey), + nodeKey.String(), newNode, registerCacheExpiration, ) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7f3b23cf..5c0baa78 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -8,6 +8,7 @@ import ( "html/template" "net/http" "strconv" + "strings" "time" "github.com/gorilla/mux" @@ -71,7 +72,7 @@ func (h *Headscale) KeyHandler( writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusOK) _, err := writer.Write( - []byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())), + []byte(strings.TrimPrefix(h.privateKey2019.Public().String(), "mkey:")), ) if err != nil { log.Error(). @@ -229,7 +230,7 @@ func (h *Headscale) RegisterWebAPI( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(nodeKeyStr), ) if !ok || nodeKeyStr == "" || err != nil { diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index ed997df4..7f9e17cf 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -369,7 +369,7 @@ func (m *Mapper) marshalMapResponse( atomic.AddUint64(&m.seq, 1) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey))) + err := machineKey.UnmarshalText([]byte(node.MachineKey)) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index b32d7513..a5ddd973 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -122,7 +122,7 @@ func (h *Headscale) RegisterOIDC( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(nodeKeyStr), ) if !ok || nodeKeyStr == "" || err != nil { @@ -154,7 +154,7 @@ func (h *Headscale) RegisterOIDC( // place the node key into the state cache, so it can be retrieved later h.registrationCache.Set( stateStr, - util.NodePublicKeyStripPrefix(nodeKey), + nodeKey, registerCacheExpiration, ) @@ -479,10 +479,11 @@ func (h *Headscale) validateNodeForOIDCCallback( } var nodeKey key.NodePublic - nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) + nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic) if !nodeKeyOK { log.Trace(). - Msg("requested node state key is not a string") + Interface("got", nodeKeyIf). + Msg("requested node state key is not a nodekey") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("state is invalid")) @@ -493,24 +494,6 @@ func (h *Headscale) validateNodeForOIDCCallback( return nil, false, errOIDCInvalidNodeState } - err := nodeKey.UnmarshalText( - []byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)), - ) - if err != nil { - log.Error(). - Str("nodeKey", nodeKeyFromCache). - Bool("nodeKeyOK", nodeKeyOK). - Msg("could not parse node public key") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("could not parse node public key")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, false, err - } - // retrieve node information if it exist // The error is not important, because if it does not // exist, then this is a new node and we will move diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 9a14c36a..91858a9f 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -8,7 +8,6 @@ import ( "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -91,7 +90,7 @@ func (h *Headscale) handlePoll( node.LastSeen = &now node.Hostname = mapRequest.Hostinfo.Hostname node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) - node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + node.DiscoKey = mapRequest.DiscoKey.String() node.Endpoints = mapRequest.Endpoints if err := h.db.NodeSave(node); err != nil { @@ -144,7 +143,7 @@ func (h *Headscale) handlePoll( node.LastSeen = &now node.Hostname = mapRequest.Hostinfo.Hostname node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) - node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + node.DiscoKey = mapRequest.DiscoKey.String() node.Endpoints = mapRequest.Endpoints // When a node connects to control, list the peers it has at diff --git a/hscontrol/poll_legacy.go b/hscontrol/poll_legacy.go index 2d269e1b..0cf009fa 100644 --- a/hscontrol/poll_legacy.go +++ b/hscontrol/poll_legacy.go @@ -45,7 +45,7 @@ func (h *Headscale) PollNetMapHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr)) if err != nil { log.Error(). Str("handler", "PollNetMap"). diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index f2b193c7..ae11b719 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -11,7 +11,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/tailcfg" @@ -295,7 +294,7 @@ func (node *Node) MachinePublicKey() (key.MachinePublic, error) { if node.MachineKey != "" { err := machineKey.UnmarshalText( - []byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), + []byte(node.MachineKey), ) if err != nil { return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err) @@ -309,7 +308,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) { var discoKey key.DiscoPublic if node.DiscoKey != "" { err := discoKey.UnmarshalText( - []byte(util.DiscoPublicKeyEnsurePrefix(node.DiscoKey)), + []byte(node.DiscoKey), ) if err != nil { return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err) @@ -323,7 +322,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) { func (node *Node) NodePublicKey() (key.NodePublic, error) { var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey))) + err := nodeKey.UnmarshalText([]byte(node.NodeKey)) if err != nil { return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err) } diff --git a/hscontrol/util/file.go b/hscontrol/util/file.go index 7b424da7..5b8656ff 100644 --- a/hscontrol/util/file.go +++ b/hscontrol/util/file.go @@ -11,11 +11,12 @@ import ( ) const ( - Base8 = 8 - Base10 = 10 - BitSize16 = 16 - BitSize32 = 32 - BitSize64 = 64 + Base8 = 8 + Base10 = 10 + BitSize16 = 16 + BitSize32 = 32 + BitSize64 = 64 + PermissionFallback = 0o700 ) func AbsolutePathFromConfigPath(path string) string { diff --git a/hscontrol/util/key.go b/hscontrol/util/key.go index 4eb1db6c..6501daca 100644 --- a/hscontrol/util/key.go +++ b/hscontrol/util/key.go @@ -4,106 +4,22 @@ import ( "encoding/json" "errors" "regexp" - "strings" "tailscale.com/types/key" ) -const ( - - // These constants are copied from the upstream tailscale.com/types/key - // library, because they are not exported. - // https://github.com/tailscale/tailscale/tree/main/types/key - - // nodePublicHexPrefix is the prefix used to identify a - // hex-encoded node public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - nodePublicHexPrefix = "nodekey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" - - // discoPublicHexPrefix is the prefix used to identify a - // hex-encoded disco public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - discoPublicHexPrefix = "discokey:" - - // privateKey prefix. - privateHexPrefix = "privkey:" - - PermissionFallback = 0o700 - - ZstdCompression = "zstd" -) - var ( NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") ErrCannotDecryptResponse = errors.New("cannot decrypt response") + ZstdCompression = "zstd" ) -func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { - return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) -} - -func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { - return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) -} - -func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { - return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) -} - -func MachinePublicKeyEnsurePrefix(machineKey string) string { - if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { - return machinePublicHexPrefix + machineKey - } - - return machineKey -} - -func NodePublicKeyEnsurePrefix(nodeKey string) string { - if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { - return nodePublicHexPrefix + nodeKey - } - - return nodeKey -} - -func DiscoPublicKeyEnsurePrefix(discoKey string) string { - if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { - return discoPublicHexPrefix + discoKey - } - - return discoKey -} - -func PrivateKeyEnsurePrefix(privateKey string) string { - if !strings.HasPrefix(privateKey, privateHexPrefix) { - return privateHexPrefix + privateKey - } - - return privateKey -} - func DecodeAndUnmarshalNaCl( msg []byte, output interface{}, pubKey *key.MachinePublic, privKey *key.MachinePrivate, ) error { - // log.Trace(). - // Str("pubkey", pubKey.ShortString()). - // Int("length", len(msg)). - // Msg("Trying to decrypt") - decrypted, ok := privKey.OpenFrom(*pubKey, msg) if !ok { return ErrCannotDecryptResponse diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 64aaebb7..5019895a 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -348,6 +348,14 @@ func (t *HeadscaleInContainer) Shutdown() error { ) } + err = t.SaveDatabase("/tmp/control") + if err != nil { + log.Printf( + "Failed to save database from control: %s", + fmt.Errorf("failed to save database from control: %w", err), + ) + } + return t.pool.Purge(t.container) } @@ -393,6 +401,24 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { return nil } +func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { + tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") + if err != nil { + return err + } + + err = os.WriteFile( + path.Join(savePath, t.hostname+".db.tar"), + tarFile, + os.ModePerm, + ) + if err != nil { + return err + } + + return nil +} + // Execute runs a command inside the Headscale container and returns the // result of stdout as a string. func (t *HeadscaleInContainer) Execute(