Unmarshal keys in the non-deprecated way

This commit is contained in:
Kristoffer Dalby 2021-11-26 23:50:42 +00:00
parent 0012c76170
commit c38f00fab8
7 changed files with 27 additions and 40 deletions

5
api.go
View file

@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/mem"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -74,7 +73,9 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
func (h *Headscale) RegistrationHandler(ctx *gin.Context) { func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
machineKeyStr := ctx.Param("id") machineKeyStr := ctx.Param("id")
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -11,7 +11,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go4.org/mem"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -486,7 +485,8 @@ func nodesToPtables(
expiry = machine.Expiry.AsTime() expiry = machine.Expiry.AsTime()
} }
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -720,6 +720,7 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
[]string{}, []string{},
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
fmt.Println("Error: ", err)
var listOnlySharedMachineNamespace []v1.Machine var listOnlySharedMachineNamespace []v1.Machine
err = json.Unmarshal( err = json.Unmarshal(
@ -728,6 +729,8 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
) )
assert.Nil(s.T(), err) assert.Nil(s.T(), err)
fmt.Println("List: ", listOnlySharedMachineNamespaceResult)
fmt.Println("List2: ", listOnlySharedMachineNamespace)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2) assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id) assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id)

View file

@ -12,7 +12,6 @@ import (
"github.com/fatih/set" "github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/mem"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/datatypes" "gorm.io/datatypes"
"inet.af/netaddr" "inet.af/netaddr"
@ -439,7 +438,8 @@ func (machine Machine) toNode(
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
includeRoutes bool, includeRoutes bool,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().
@ -449,19 +449,18 @@ func (machine Machine) toNode(
return nil, fmt.Errorf("failed to parse node public key: %w", err) return nil, fmt.Errorf("failed to parse node public key: %w", err)
} }
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey)) var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machine.MachineKey))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err) return nil, fmt.Errorf("failed to parse machine public key: %w", err)
} }
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
if machine.DiscoKey != "" { if machine.DiscoKey != "" {
dKey := key.DiscoPublic{} err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse disco public key: %w", err) return nil, fmt.Errorf("failed to parse disco public key: %w", err)
} }
discoKey = key.DiscoPublic(dKey)
} else { } else {
discoKey = key.DiscoPublic{} discoKey = key.DiscoPublic{}
} }
@ -634,7 +633,8 @@ func (h *Headscale) RegisterMachine(
return nil, err return nil, err
} }
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,7 +15,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/mem"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -192,7 +191,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machineKeyStr, machineKeyOK := machineKeyIf.(string) machineKeyStr, machineKeyOK := machineKeyIf.(string)
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse machine public key")

19
poll.go
View file

@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/mem"
"gorm.io/datatypes" "gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -36,8 +35,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("id", ctx.Param("id")). Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called") Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := ctx.Param("id") machineKeyStr := ctx.Param("id")
mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr))
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -48,7 +49,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
return return
} }
req := tailcfg.MapRequest{} req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -59,19 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
return return
} }
machine, err := h.GetMachineByMachineKey(mKey) machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.String()) Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "") ctx.String(http.StatusUnauthorized, "")
return return
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String()) Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
} }
log.Trace(). log.Trace().
@ -101,7 +102,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
} }
h.db.Save(&machine) h.db.Save(&machine)
data, err := h.getMapResponse(mKey, req, machine) data, err := h.getMapResponse(machineKey, req, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -206,7 +207,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
ctx, ctx,
machine, machine,
req, req,
mKey, machineKey,
pollDataChan, pollDataChan,
keepAliveChan, keepAliveChan,
updateChan, updateChan,

View file

@ -20,22 +20,12 @@ import (
const ( const (
errCannotDecryptReponse = Error("cannot decrypt response") errCannotDecryptReponse = Error("cannot decrypt response")
errResponseMissingNonce = Error("response missing nonce")
errCouldNotAllocateIP = Error("could not find any suitable IP") errCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key // These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported. // library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key // https://github.com/tailscale/tailscale/tree/main/types/key
// nodePrivateHexPrefix is the prefix used to identify a
// hex-encoded node private key.
//
// This prefix name is a little unfortunate, in that it comes from
// WireGuard's own key types, and we've used it for both key types
// we persist to disk (machine and node keys). But we're stuck
// with it for now, barring another round of tricky migration.
nodePrivateHexPrefix = "privkey:"
// nodePublicHexPrefix is the prefix used to identify a // nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key. // hex-encoded node public key.
// //
@ -43,14 +33,6 @@ const (
// changed. // changed.
nodePublicHexPrefix = "nodekey:" nodePublicHexPrefix = "nodekey:"
// machinePrivateHexPrefix is the prefix used to identify a
// hex-encoded machine private key.
//
// This prefix name is a little unfortunate, in that it comes from
// WireGuard's own key types. Unfortunately we're stuck with it for
// machine keys, because we serialize them to disk with this prefix.
machinePrivateHexPrefix = "privkey:"
// machinePublicHexPrefix is the prefix used to identify a // machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key. // hex-encoded machine public key.
// //