Unmarshal keys in the non-deprecated way

This commit is contained in:
Kristoffer Dalby 2021-11-26 23:50:42 +00:00
parent 0012c76170
commit c38f00fab8
7 changed files with 27 additions and 40 deletions

5
api.go
View file

@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"go4.org/mem"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -74,7 +73,9 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
machineKeyStr := ctx.Param("id")
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil {
log.Error().
Caller().

View file

@ -11,7 +11,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
"go4.org/mem"
"google.golang.org/grpc/status"
"tailscale.com/types/key"
)
@ -486,7 +485,8 @@ func nodesToPtables(
expiry = machine.Expiry.AsTime()
}
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
if err != nil {
return nil, err
}

View file

@ -720,6 +720,7 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
[]string{},
)
assert.Nil(s.T(), err)
fmt.Println("Error: ", err)
var listOnlySharedMachineNamespace []v1.Machine
err = json.Unmarshal(
@ -728,6 +729,8 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
)
assert.Nil(s.T(), err)
fmt.Println("List: ", listOnlySharedMachineNamespaceResult)
fmt.Println("List2: ", listOnlySharedMachineNamespace)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id)

View file

@ -12,7 +12,6 @@ import (
"github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"go4.org/mem"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/datatypes"
"inet.af/netaddr"
@ -439,7 +438,8 @@ func (machine Machine) toNode(
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
if err != nil {
log.Trace().
Caller().
@ -449,19 +449,18 @@ func (machine Machine) toNode(
return nil, fmt.Errorf("failed to parse node public key: %w", err)
}
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey))
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machine.MachineKey))
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
var discoKey key.DiscoPublic
if machine.DiscoKey != "" {
dKey := key.DiscoPublic{}
err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
if err != nil {
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
}
discoKey = key.DiscoPublic(dKey)
} else {
discoKey = key.DiscoPublic{}
}
@ -634,7 +633,8 @@ func (h *Headscale) RegisterMachine(
return nil, err
}
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil {
return nil, err
}

View file

@ -15,7 +15,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"go4.org/mem"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key"
@ -192,7 +191,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machineKeyStr, machineKeyOK := machineKeyIf.(string)
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil {
log.Error().
Msg("could not parse machine public key")

19
poll.go
View file

@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"go4.org/mem"
"gorm.io/datatypes"
"gorm.io/gorm"
"tailscale.com/tailcfg"
@ -36,8 +35,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := ctx.Param("id")
mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr))
machineKeyStr := ctx.Param("id")
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
@ -48,7 +49,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
err = decode(body, &req, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
@ -59,19 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
return
}
machine, err := h.GetMachineByMachineKey(mKey)
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.String())
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String())
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
}
log.Trace().
@ -101,7 +102,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
}
h.db.Save(&machine)
data, err := h.getMapResponse(mKey, req, machine)
data, err := h.getMapResponse(machineKey, req, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
@ -206,7 +207,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
ctx,
machine,
req,
mKey,
machineKey,
pollDataChan,
keepAliveChan,
updateChan,

View file

@ -20,22 +20,12 @@ import (
const (
errCannotDecryptReponse = Error("cannot decrypt response")
errResponseMissingNonce = Error("response missing nonce")
errCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key
// nodePrivateHexPrefix is the prefix used to identify a
// hex-encoded node private key.
//
// This prefix name is a little unfortunate, in that it comes from
// WireGuard's own key types, and we've used it for both key types
// we persist to disk (machine and node keys). But we're stuck
// with it for now, barring another round of tricky migration.
nodePrivateHexPrefix = "privkey:"
// nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key.
//
@ -43,14 +33,6 @@ const (
// changed.
nodePublicHexPrefix = "nodekey:"
// machinePrivateHexPrefix is the prefix used to identify a
// hex-encoded machine private key.
//
// This prefix name is a little unfortunate, in that it comes from
// WireGuard's own key types. Unfortunately we're stuck with it for
// machine keys, because we serialize them to disk with this prefix.
machinePrivateHexPrefix = "privkey:"
// machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key.
//