Add test case and fix nil pointer in preauthkeys command without expiration

This commit is contained in:
Kristoffer Dalby 2021-11-08 08:02:01 +00:00
parent 9a26fa7989
commit dce6b8d72e
3 changed files with 82 additions and 26 deletions

View file

@ -2,13 +2,12 @@ package cli
import ( import (
"fmt" "fmt"
"log"
"strconv" "strconv"
"time" "time"
"github.com/hako/durafmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
@ -18,7 +17,7 @@ func init() {
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace") preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
err := preauthkeysCmd.MarkPersistentFlagRequired("namespace") err := preauthkeysCmd.MarkPersistentFlagRequired("namespace")
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatal().Err(err).Msg("")
} }
preauthkeysCmd.AddCommand(listPreAuthKeys) preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd) preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
@ -26,7 +25,7 @@ func init() {
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable") createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes") createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags(). createPreAuthKeyCmd.Flags().
StringP("expiration", "e", "", "Human-readable expiration of the key (30m, 24h, 365d...)") DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
} }
var preauthkeysCmd = &cobra.Command{ var preauthkeysCmd = &cobra.Command{
@ -92,7 +91,8 @@ var listPreAuthKeys = &cobra.Command{
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil { if err != nil {
log.Fatal(err) ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
return
} }
}, },
} }
@ -103,7 +103,7 @@ var createPreAuthKeyCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
@ -112,28 +112,25 @@ var createPreAuthKeyCmd = &cobra.Command{
reusable, _ := cmd.Flags().GetBool("reusable") reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral") ephemeral, _ := cmd.Flags().GetBool("ephemeral")
e, _ := cmd.Flags().GetString("expiration") request := &v1.CreatePreAuthKeyRequest{
var expiration *time.Time Namespace: namespace,
if e != "" { Resuable: reusable,
duration, err := durafmt.ParseStringShort(e) Ephemeral: ephemeral,
if err != nil {
log.Fatalf("Error parsing expiration: %s", err)
} }
exp := time.Now().UTC().Add(duration.Duration())
expiration = &exp if cmd.Flags().Changed("expiration") {
duration, _ := cmd.Flags().GetDuration("expiration")
expiration := time.Now().UTC().Add(duration)
log.Trace().Dur("expiration", duration).Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration)
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := getHeadscaleCLIClient()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
request := &v1.CreatePreAuthKeyRequest{
Namespace: n,
Resuable: reusable,
Ephemeral: ephemeral,
Expiration: timestamppb.New(*expiration),
}
response, err := client.CreatePreAuthKey(ctx, request) response, err := client.CreatePreAuthKey(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output) ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
@ -155,9 +152,10 @@ var expirePreAuthKeyCmd = &cobra.Command{
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
log.Fatalf("Error getting namespace: %s", err) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := getHeadscaleCLIClient()
@ -165,7 +163,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
defer conn.Close() defer conn.Close()
request := &v1.ExpirePreAuthKeyRequest{ request := &v1.ExpirePreAuthKeyRequest{
Namespace: n, Namespace: namespace,
Key: args[0], Key: args[0],
} }

View file

@ -99,7 +99,11 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
ctx context.Context, ctx context.Context,
request *v1.CreatePreAuthKeyRequest, request *v1.CreatePreAuthKeyRequest,
) (*v1.CreatePreAuthKeyResponse, error) { ) (*v1.CreatePreAuthKeyResponse, error) {
expiration := request.GetExpiration().AsTime() var expiration time.Time
if request.GetExpiration() != nil {
expiration = request.GetExpiration().AsTime()
}
preAuthKey, err := api.h.CreatePreAuthKey( preAuthKey, err := api.h.CreatePreAuthKey(
request.GetNamespace(), request.GetNamespace(),
request.GetResuable(), request.GetResuable(),

View file

@ -298,6 +298,12 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
// Expire three keys // Expire three keys
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := ExecuteCommand( _, err := ExecuteCommand(
@ -342,6 +348,54 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
} }
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
namespace, err := s.createNamespace("pre-auth-key-without-exp-namespace")
assert.Nil(s.T(), err)
preAuthResult, err := ExecuteCommand(
&s.headscale,
[]string{
"headscale",
"preauthkeys",
"--namespace",
namespace.Name,
"create",
"--reusable",
"--output",
"json",
},
[]string{},
)
assert.Nil(s.T(), err)
var preAuthKey v1.PreAuthKey
err = json.Unmarshal([]byte(preAuthResult), &preAuthKey)
assert.Nil(s.T(), err)
// Test list of keys
listResult, err := ExecuteCommand(
&s.headscale,
[]string{
"headscale",
"preauthkeys",
"--namespace",
namespace.Name,
"list",
"--output",
"json",
},
[]string{},
)
assert.Nil(s.T(), err)
var listedPreAuthKeys []v1.PreAuthKey
err = json.Unmarshal([]byte(listResult), &listedPreAuthKeys)
assert.Nil(s.T(), err)
assert.Len(s.T(), listedPreAuthKeys, 1)
assert.True(s.T(), time.Time{}.Equal(listedPreAuthKeys[0].Expiration.AsTime()))
}
func (s *IntegrationCLITestSuite) TestNodeCommand() { func (s *IntegrationCLITestSuite) TestNodeCommand() {
namespace, err := s.createNamespace("machine-namespace") namespace, err := s.createNamespace("machine-namespace")
assert.Nil(s.T(), err) assert.Nil(s.T(), err)