remove the use key stripping and store the proper keys (#1603)

This commit is contained in:
Kristoffer Dalby 2023-11-16 17:55:29 +01:00 committed by GitHub
parent 2af71c9e31
commit c0fd06e3f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 99 additions and 198 deletions

View file

@ -529,7 +529,7 @@ func nodesToPtables(
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText( err := machineKey.UnmarshalText(
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), []byte(node.MachineKey),
) )
if err != nil { if err != nil {
machineKey = key.MachinePublic{} machineKey = key.MachinePublic{}
@ -537,7 +537,7 @@ func nodesToPtables(
var nodeKey key.NodePublic var nodeKey key.NodePublic
err = nodeKey.UnmarshalText( err = nodeKey.UnmarshalText(
[]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey)), []byte(node.NodeKey),
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -911,10 +911,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
} }
trimmedPrivateKey := strings.TrimSpace(string(privateKey)) trimmedPrivateKey := strings.TrimSpace(string(privateKey))
privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey)
var machineKey key.MachinePrivate var machineKey key.MachinePrivate
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil { if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
log.Info(). log.Info().
Str("path", path). Str("path", path).
Msg("This might be due to a legacy (headscale pre-0.12) private key. " + Msg("This might be due to a legacy (headscale pre-0.12) private key. " +

View file

@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a // is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse. // successful RegisterResponse.
if registerRequest.Followup != "" { if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
log.Debug(). log.Debug().
Caller(). Caller().
Str("node", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
// We create the node and then keep it around until a callback // We create the node and then keep it around until a callback
// happens // happens
newNode := types.Node{ newNode := types.Node{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey), MachineKey: machineKey.String(),
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey), NodeKey: registerRequest.NodeKey.String(),
LastSeen: &now, LastSeen: &now,
Expiry: &time.Time{}, Expiry: &time.Time{},
} }
@ -136,7 +136,7 @@ func (h *Headscale) handleRegister(
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it. // So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText( err = storedMachineKey.UnmarshalText(
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), []byte(node.MachineKey),
) )
if err != nil || storedMachineKey.IsZero() { if err != nil || storedMachineKey.IsZero() {
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil { if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
@ -156,7 +156,7 @@ func (h *Headscale) handleRegister(
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
// - A valid, registered node, looking for /map // - A valid, registered node, looking for /map
// - Expired node wanting to reauthenticate // - Expired node wanting to reauthenticate
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) { if node.NodeKey == registerRequest.NodeKey.String() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && if !registerRequest.Expiry.IsZero() &&
@ -176,7 +176,7 @@ func (h *Headscale) handleRegister(
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && if node.NodeKey == registerRequest.OldNodeKey.String() &&
!node.IsExpired() { !node.IsExpired() {
h.handleNodeKeyRefresh( h.handleNodeKeyRefresh(
writer, writer,
@ -207,9 +207,9 @@ func (h *Headscale) handleRegister(
// we need to make sure the NodeKey matches the one in the request // we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two // TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets? // headscale-managed tailnets?
node.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) node.NodeKey = registerRequest.NodeKey.String()
h.registrationCache.Set( h.registrationCache.Set(
util.NodePublicKeyStripPrefix(registerRequest.NodeKey), registerRequest.NodeKey.String(),
*node, *node,
registerCacheExpiration, registerCacheExpiration,
) )
@ -294,7 +294,7 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey) nodeKey := registerRequest.NodeKey.String()
// retrieve node information if it exist // retrieve node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
@ -342,7 +342,7 @@ func (h *Headscale) handleAuthKey(
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -359,7 +359,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
UserID: pak.User.ID, UserID: pak.User.ID,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey), MachineKey: machineKey.String(),
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry, Expiry: &registerRequest.Expiry,
NodeKey: nodeKey, NodeKey: nodeKey,

View file

@ -33,7 +33,7 @@ func (h *Headscale) RegistrationHandler(
body, _ := io.ReadAll(req.Body) body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -35,9 +35,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -83,9 +80,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
node := types.Node{ node := types.Node{
ID: uint64(index), ID: uint64(index),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -173,9 +167,6 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"sync" "sync"
"time" "time"
@ -252,6 +253,27 @@ func NewHeadscaleDatabase(
return nil, err return nil, err
} }
// Ensure all keys have correct prefixes
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
nodes := types.Nodes{}
if err := dbConn.Find(&nodes).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
for _, node := range nodes {
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
node.DiscoKey = "discokey:" + node.DiscoKey
}
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
node.NodeKey = "nodekey:" + node.NodeKey
}
if !strings.HasPrefix(node.MachineKey, "mkey:") {
node.MachineKey = "mkey:" + node.MachineKey
}
}
// TODO(kradalby): is this needed? // TODO(kradalby): is this needed?
err = db.setValue("db_version", dbVersion) err = db.setValue("db_version", dbVersion)

View file

@ -182,7 +182,7 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
Preload("Routes"). Preload("Routes").
First(&mach, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -203,7 +203,7 @@ func (hsdb *HSDatabase) GetNodeByNodeKey(
Preload("User"). Preload("User").
Preload("Routes"). Preload("Routes").
First(&node, "node_key = ?", First(&node, "node_key = ?",
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { nodeKey.String()); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -224,9 +224,9 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
Preload("User"). Preload("User").
Preload("Routes"). Preload("Routes").
First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
util.MachinePublicKeyStripPrefix(machineKey), machineKey.String(),
util.NodePublicKeyStripPrefix(nodeKey), nodeKey.String(),
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { oldNodeKey.String()); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -397,7 +397,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback") Msg("Registering node from API/CLI or auth callback")
if nodeInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok { if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName) user, err := hsdb.getUser(userName)
if err != nil { if err != nil {
@ -507,7 +507,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
defer hsdb.mu.Unlock() defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{ if err := hsdb.db.Model(node).Updates(types.Node{
NodeKey: util.NodePublicKeyStripPrefix(nodeKey), NodeKey: nodeKey.String(),
}).Error; err != nil { }).Error; err != nil {
return err return err
} }
@ -524,7 +524,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
defer hsdb.mu.Unlock() defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{ if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey), MachineKey: machineKey.String(),
}).Error; err != nil { }).Error; err != nil {
return err return err
} }

View file

@ -82,8 +82,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), MachineKey: machineKey.Public().String(),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: nodeKey.Public().String(),
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
@ -113,8 +113,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), MachineKey: machineKey.Public().String(),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: nodeKey.Public().String(),
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
@ -575,7 +575,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: nodeKey.Public().String(),
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test", Hostname: "test",
UserID: user.ID, UserID: user.ID,

View file

@ -77,9 +77,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -101,9 +98,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
node := types.Node{ node := types.Node{
ID: 1, ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -138,9 +132,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
now := time.Now().Add(-time.Second * 30) now := time.Now().Add(-time.Second * 30)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,

View file

@ -29,9 +29,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_get_route_node", Hostname: "test_get_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -80,9 +77,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -154,9 +148,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
} }
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -179,9 +170,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
} }
node2 := types.Node{ node2 := types.Node{
ID: 2, ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -240,9 +228,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
now := time.Now() now := time.Now()
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -277,9 +262,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
} }
node2 := types.Node{ node2 := types.Node{
ID: 2, ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -382,9 +364,6 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
now := time.Now() now := time.Now()
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,

View file

@ -48,9 +48,6 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -103,9 +100,6 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode", Hostname: "testnode",
UserID: oldUser.ID, UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,

View file

@ -545,7 +545,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
} }
api.h.registrationCache.Set( api.h.registrationCache.Set(
util.NodePublicKeyStripPrefix(nodeKey), nodeKey.String(),
newNode, newNode,
registerCacheExpiration, registerCacheExpiration,
) )

View file

@ -8,6 +8,7 @@ import (
"html/template" "html/template"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -71,7 +72,7 @@ func (h *Headscale) KeyHandler(
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write( _, err := writer.Write(
[]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())), []byte(strings.TrimPrefix(h.privateKey2019.Public().String(), "mkey:")),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
@ -229,7 +230,7 @@ func (h *Headscale) RegisterWebAPI(
// the template and log an error. // the template and log an error.
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText( err := nodeKey.UnmarshalText(
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), []byte(nodeKeyStr),
) )
if !ok || nodeKeyStr == "" || err != nil { if !ok || nodeKeyStr == "" || err != nil {

View file

@ -369,7 +369,7 @@ func (m *Mapper) marshalMapResponse(
atomic.AddUint64(&m.seq, 1) atomic.AddUint64(&m.seq, 1)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey))) err := machineKey.UnmarshalText([]byte(node.MachineKey))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -122,7 +122,7 @@ func (h *Headscale) RegisterOIDC(
// the template and log an error. // the template and log an error.
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText( err := nodeKey.UnmarshalText(
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), []byte(nodeKeyStr),
) )
if !ok || nodeKeyStr == "" || err != nil { if !ok || nodeKeyStr == "" || err != nil {
@ -154,7 +154,7 @@ func (h *Headscale) RegisterOIDC(
// place the node key into the state cache, so it can be retrieved later // place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set( h.registrationCache.Set(
stateStr, stateStr,
util.NodePublicKeyStripPrefix(nodeKey), nodeKey,
registerCacheExpiration, registerCacheExpiration,
) )
@ -479,10 +479,11 @@ func (h *Headscale) validateNodeForOIDCCallback(
} }
var nodeKey key.NodePublic var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic)
if !nodeKeyOK { if !nodeKeyOK {
log.Trace(). log.Trace().
Msg("requested node state key is not a string") Interface("got", nodeKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid")) _, err := writer.Write([]byte("state is invalid"))
@ -493,24 +494,6 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, false, errOIDCInvalidNodeState return nil, false, errOIDCInvalidNodeState
} }
err := nodeKey.UnmarshalText(
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
)
if err != nil {
log.Error().
Str("nodeKey", nodeKeyFromCache).
Bool("nodeKeyOK", nodeKeyOK).
Msg("could not parse node public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse node public key"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, err
}
// retrieve node information if it exist // retrieve node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new node and we will move // exist, then this is a new node and we will move

View file

@ -8,7 +8,6 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -91,7 +90,7 @@ func (h *Headscale) handlePoll(
node.LastSeen = &now node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) node.DiscoKey = mapRequest.DiscoKey.String()
node.Endpoints = mapRequest.Endpoints node.Endpoints = mapRequest.Endpoints
if err := h.db.NodeSave(node); err != nil { if err := h.db.NodeSave(node); err != nil {
@ -144,7 +143,7 @@ func (h *Headscale) handlePoll(
node.LastSeen = &now node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) node.DiscoKey = mapRequest.DiscoKey.String()
node.Endpoints = mapRequest.Endpoints node.Endpoints = mapRequest.Endpoints
// When a node connects to control, list the peers it has at // When a node connects to control, list the peers it has at

View file

@ -45,7 +45,7 @@ func (h *Headscale) PollNetMapHandler(
body, _ := io.ReadAll(req.Body) body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").

View file

@ -11,7 +11,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -295,7 +294,7 @@ func (node *Node) MachinePublicKey() (key.MachinePublic, error) {
if node.MachineKey != "" { if node.MachineKey != "" {
err := machineKey.UnmarshalText( err := machineKey.UnmarshalText(
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), []byte(node.MachineKey),
) )
if err != nil { if err != nil {
return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err) return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err)
@ -309,7 +308,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
if node.DiscoKey != "" { if node.DiscoKey != "" {
err := discoKey.UnmarshalText( err := discoKey.UnmarshalText(
[]byte(util.DiscoPublicKeyEnsurePrefix(node.DiscoKey)), []byte(node.DiscoKey),
) )
if err != nil { if err != nil {
return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err) return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err)
@ -323,7 +322,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
func (node *Node) NodePublicKey() (key.NodePublic, error) { func (node *Node) NodePublicKey() (key.NodePublic, error) {
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey))) err := nodeKey.UnmarshalText([]byte(node.NodeKey))
if err != nil { if err != nil {
return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err) return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err)
} }

View file

@ -16,6 +16,7 @@ const (
BitSize16 = 16 BitSize16 = 16
BitSize32 = 32 BitSize32 = 32
BitSize64 = 64 BitSize64 = 64
PermissionFallback = 0o700
) )
func AbsolutePathFromConfigPath(path string) string { func AbsolutePathFromConfigPath(path string) string {

View file

@ -4,106 +4,22 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"regexp" "regexp"
"strings"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
const (
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key
// nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
nodePublicHexPrefix = "nodekey:"
// machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
machinePublicHexPrefix = "mkey:"
// discoPublicHexPrefix is the prefix used to identify a
// hex-encoded disco public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
discoPublicHexPrefix = "discokey:"
// privateKey prefix.
privateHexPrefix = "privkey:"
PermissionFallback = 0o700
ZstdCompression = "zstd"
)
var ( var (
NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
ErrCannotDecryptResponse = errors.New("cannot decrypt response") ErrCannotDecryptResponse = errors.New("cannot decrypt response")
ZstdCompression = "zstd"
) )
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
}
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
}
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
}
func MachinePublicKeyEnsurePrefix(machineKey string) string {
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
return machinePublicHexPrefix + machineKey
}
return machineKey
}
func NodePublicKeyEnsurePrefix(nodeKey string) string {
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
return nodePublicHexPrefix + nodeKey
}
return nodeKey
}
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
return discoPublicHexPrefix + discoKey
}
return discoKey
}
func PrivateKeyEnsurePrefix(privateKey string) string {
if !strings.HasPrefix(privateKey, privateHexPrefix) {
return privateHexPrefix + privateKey
}
return privateKey
}
func DecodeAndUnmarshalNaCl( func DecodeAndUnmarshalNaCl(
msg []byte, msg []byte,
output interface{}, output interface{},
pubKey *key.MachinePublic, pubKey *key.MachinePublic,
privKey *key.MachinePrivate, privKey *key.MachinePrivate,
) error { ) error {
// log.Trace().
// Str("pubkey", pubKey.ShortString()).
// Int("length", len(msg)).
// Msg("Trying to decrypt")
decrypted, ok := privKey.OpenFrom(*pubKey, msg) decrypted, ok := privKey.OpenFrom(*pubKey, msg)
if !ok { if !ok {
return ErrCannotDecryptResponse return ErrCannotDecryptResponse

View file

@ -348,6 +348,14 @@ func (t *HeadscaleInContainer) Shutdown() error {
) )
} }
err = t.SaveDatabase("/tmp/control")
if err != nil {
log.Printf(
"Failed to save database from control: %s",
fmt.Errorf("failed to save database from control: %w", err),
)
}
return t.pool.Purge(t.container) return t.pool.Purge(t.container)
} }
@ -393,6 +401,24 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
return nil return nil
} }
func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
return err
}
err = os.WriteFile(
path.Join(savePath, t.hostname+".db.tar"),
tarFile,
os.ModePerm,
)
if err != nil {
return err
}
return nil
}
// Execute runs a command inside the Headscale container and returns the // Execute runs a command inside the Headscale container and returns the
// result of stdout as a string. // result of stdout as a string.
func (t *HeadscaleInContainer) Execute( func (t *HeadscaleInContainer) Execute(