Only load needed part of configuration (#2109)

This commit is contained in:
Kristoffer Dalby 2024-09-07 09:23:58 +02:00 committed by GitHub
parent f368ed01ed
commit 8a3a0fee3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 196 additions and 324 deletions

1
.gitignore vendored
View file

@ -22,6 +22,7 @@ dist/
/headscale /headscale
config.json config.json
config.yaml config.yaml
config*.yaml
derp.yaml derp.yaml
*.hujson *.hujson
*.key *.key

View file

@ -72,6 +72,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792) - Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792)
- Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764) - Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764)
- Make sure integration tests cover postgres for all scenarios - Make sure integration tests cover postgres for all scenarios
- CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109)
- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

View file

@ -54,7 +54,7 @@ var listAPIKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -67,14 +67,10 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Error getting the list of keys: %s", err), fmt.Sprintf("Error getting the list of keys: %s", err),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response.GetApiKeys(), "", output) SuccessOutput(response.GetApiKeys(), "", output)
return
} }
tableData := pterm.TableData{ tableData := pterm.TableData{
@ -102,8 +98,6 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return
} }
}, },
} }
@ -119,9 +113,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
log.Trace().
Msg("Preparing to create ApiKey")
request := &v1.CreateApiKeyRequest{} request := &v1.CreateApiKeyRequest{}
durationStr, _ := cmd.Flags().GetString("expiration") durationStr, _ := cmd.Flags().GetString("expiration")
@ -133,19 +124,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Could not parse duration: %s\n", err), fmt.Sprintf("Could not parse duration: %s\n", err),
output, output,
) )
return
} }
expiration := time.Now().UTC().Add(time.Duration(duration)) expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration) request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -156,8 +141,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Cannot create Api Key: %s\n", err), fmt.Sprintf("Cannot create Api Key: %s\n", err),
output, output,
) )
return
} }
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output) SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
@ -178,11 +161,9 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err), fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -197,8 +178,6 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Api Key: %s\n", err), fmt.Sprintf("Cannot expire Api Key: %s\n", err),
output, output,
) )
return
} }
SuccessOutput(response, "Key expired", output) SuccessOutput(response, "Key expired", output)
@ -219,11 +198,9 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err), fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -238,8 +215,6 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot delete Api Key: %s\n", err), fmt.Sprintf("Cannot delete Api Key: %s\n", err),
output, output,
) )
return
} }
SuccessOutput(response, "Key deleted", output) SuccessOutput(response, "Key deleted", output)

View file

@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{
Short: "Test the configuration.", Short: "Test the configuration.",
Long: "Run a test of the configuration and exit.", Long: "Run a test of the configuration and exit.",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
_, err := getHeadscaleApp() _, err := newHeadscaleServerWithConfig()
if err != nil { if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing") log.Fatal().Caller().Err(err).Msg("Error initializing")
} }

View file

@ -64,11 +64,9 @@ var createNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -79,8 +77,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node from flag: %s", err), fmt.Sprintf("Error getting node from flag: %s", err),
output, output,
) )
return
} }
machineKey, err := cmd.Flags().GetString("key") machineKey, err := cmd.Flags().GetString("key")
@ -90,8 +86,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting key from flag: %s", err), fmt.Sprintf("Error getting key from flag: %s", err),
output, output,
) )
return
} }
var mkey key.MachinePublic var mkey key.MachinePublic
@ -102,8 +96,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Failed to parse machine key from flag: %s", err), fmt.Sprintf("Failed to parse machine key from flag: %s", err),
output, output,
) )
return
} }
routes, err := cmd.Flags().GetStringSlice("route") routes, err := cmd.Flags().GetStringSlice("route")
@ -113,8 +105,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting routes from flag: %s", err), fmt.Sprintf("Error getting routes from flag: %s", err),
output, output,
) )
return
} }
request := &v1.DebugCreateNodeRequest{ request := &v1.DebugCreateNodeRequest{
@ -131,8 +121,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
output, output,
) )
return
} }
SuccessOutput(response.GetNode(), "Node created", output) SuccessOutput(response.GetNode(), "Node created", output)

View file

@ -116,11 +116,9 @@ var registerNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -131,8 +129,6 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node key from flag: %s", err), fmt.Sprintf("Error getting node key from flag: %s", err),
output, output,
) )
return
} }
request := &v1.RegisterNodeRequest{ request := &v1.RegisterNodeRequest{
@ -150,8 +146,6 @@ var registerNodeCmd = &cobra.Command{
), ),
output, output,
) )
return
} }
SuccessOutput( SuccessOutput(
@ -169,17 +163,13 @@ var listNodesCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
showTags, err := cmd.Flags().GetBool("tags") showTags, err := cmd.Flags().GetBool("tags")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -194,21 +184,15 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response.GetNodes(), "", output) SuccessOutput(response.GetNodes(), "", output)
return
} }
tableData, err := nodesToPtables(user, showTags, response.GetNodes()) tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -218,8 +202,6 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return
} }
}, },
} }
@ -243,7 +225,7 @@ var expireNodeCmd = &cobra.Command{
return return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -286,7 +268,7 @@ var renameNodeCmd = &cobra.Command{
return return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -335,7 +317,7 @@ var deleteNodeCmd = &cobra.Command{
return return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -435,7 +417,7 @@ var moveNodeCmd = &cobra.Command{
return return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -508,7 +490,7 @@ be assigned to nodes.`,
return return
} }
if confirm { if confirm {
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -681,7 +663,7 @@ var tagCmd = &cobra.Command{
Aliases: []string{"tags", "t"}, Aliases: []string{"tags", "t"},
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()

View file

@ -1,6 +1,7 @@
package cli package cli
import ( import (
"fmt"
"io" "io"
"os" "os"
@ -30,7 +31,8 @@ var getPolicy = &cobra.Command{
Short: "Print the current ACL Policy", Short: "Print the current ACL Policy",
Aliases: []string{"show", "view", "fetch"}, Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
ctx, client, conn, cancel := getHeadscaleCLIClient() output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -38,13 +40,13 @@ var getPolicy = &cobra.Command{
response, err := client.GetPolicy(ctx, request) response, err := client.GetPolicy(ctx, request)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to get the policy") ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
return
} }
// TODO(pallabpain): Maybe print this better? // TODO(pallabpain): Maybe print this better?
SuccessOutput("", response.GetPolicy(), "hujson") // This does not pass output as we dont support yaml, json or json-line
// output for this command. It is HuJSON already.
SuccessOutput("", response.GetPolicy(), "")
}, },
} }
@ -56,33 +58,28 @@ var setPolicy = &cobra.Command{
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`, This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
Aliases: []string{"put", "update"}, Aliases: []string{"put", "update"},
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
policyPath, _ := cmd.Flags().GetString("file") policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath) f, err := os.Open(policyPath)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Error opening the policy file") ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
return
} }
defer f.Close() defer f.Close()
policyBytes, err := io.ReadAll(f) policyBytes, err := io.ReadAll(f)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Error reading the policy file") ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
return
} }
request := &v1.SetPolicyRequest{Policy: string(policyBytes)} request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
if _, err := client.SetPolicy(ctx, request); err != nil { if _, err := client.SetPolicy(ctx, request); err != nil {
log.Fatal().Err(err).Msg("Failed to set ACL Policy") ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
return
} }
SuccessOutput(nil, "Policy updated.", "") SuccessOutput(nil, "Policy updated.", "")

View file

@ -60,11 +60,9 @@ var listPreAuthKeys = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -85,8 +83,6 @@ var listPreAuthKeys = &cobra.Command{
if output != "" { if output != "" {
SuccessOutput(response.GetPreAuthKeys(), "", output) SuccessOutput(response.GetPreAuthKeys(), "", output)
return
} }
tableData := pterm.TableData{ tableData := pterm.TableData{
@ -134,8 +130,6 @@ var listPreAuthKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return
} }
}, },
} }
@ -150,20 +144,12 @@ var createPreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
reusable, _ := cmd.Flags().GetBool("reusable") reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral") ephemeral, _ := cmd.Flags().GetBool("ephemeral")
tags, _ := cmd.Flags().GetStringSlice("tags") tags, _ := cmd.Flags().GetStringSlice("tags")
log.Trace().
Bool("reusable", reusable).
Bool("ephemeral", ephemeral).
Str("user", user).
Msg("Preparing to create preauthkey")
request := &v1.CreatePreAuthKeyRequest{ request := &v1.CreatePreAuthKeyRequest{
User: user, User: user,
Reusable: reusable, Reusable: reusable,
@ -180,8 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Could not parse duration: %s\n", err), fmt.Sprintf("Could not parse duration: %s\n", err),
output, output,
) )
return
} }
expiration := time.Now().UTC().Add(time.Duration(duration)) expiration := time.Now().UTC().Add(time.Duration(duration))
@ -192,7 +176,7 @@ var createPreAuthKeyCmd = &cobra.Command{
request.Expiration = timestamppb.New(expiration) request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -203,8 +187,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output, output,
) )
return
} }
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output) SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
@ -227,11 +209,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user") user, err := cmd.Flags().GetString("user")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -247,8 +227,6 @@ var expirePreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output, output,
) )
return
} }
SuccessOutput(response, "Key expired", output) SuccessOutput(response, "Key expired", output)

View file

@ -9,6 +9,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/tcnksm/go-latest" "github.com/tcnksm/go-latest"
) )
@ -49,11 +50,6 @@ func initConfig() {
} }
} }
cfg, err := types.GetHeadscaleConfig()
if err != nil {
log.Fatal().Err(err).Msg("Failed to read headscale configuration")
}
machineOutput := HasMachineOutputFlag() machineOutput := HasMachineOutputFlag()
// If the user has requested a "node" readable format, // If the user has requested a "node" readable format,
@ -62,11 +58,13 @@ func initConfig() {
zerolog.SetGlobalLevel(zerolog.Disabled) zerolog.SetGlobalLevel(zerolog.Disabled)
} }
if cfg.Log.Format == types.JSONLogFormat { // logFormat := viper.GetString("log.format")
log.Logger = log.Output(os.Stdout) // if logFormat == types.JSONLogFormat {
} // log.Logger = log.Output(os.Stdout)
// }
if !cfg.DisableUpdateCheck && !machineOutput { disableUpdateCheck := viper.GetBool("disable_check_updates")
if !disableUpdateCheck && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
Version != "dev" { Version != "dev" {
githubTag := &latest.GithubTag{ githubTag := &latest.GithubTag{

View file

@ -64,11 +64,9 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -82,14 +80,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response.GetRoutes(), "", output) SuccessOutput(response.GetRoutes(), "", output)
return
} }
routes = response.GetRoutes() routes = response.GetRoutes()
@ -103,14 +97,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()), fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response.GetRoutes(), "", output) SuccessOutput(response.GetRoutes(), "", output)
return
} }
routes = response.GetRoutes() routes = response.GetRoutes()
@ -119,8 +109,6 @@ var listRoutesCmd = &cobra.Command{
tableData := routesToPtables(routes) tableData := routesToPtables(routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -130,8 +118,6 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return
} }
}, },
} }
@ -150,11 +136,9 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -167,14 +151,10 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()), fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response, "", output) SuccessOutput(response, "", output)
return
} }
}, },
} }
@ -193,11 +173,9 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -210,14 +188,10 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()), fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response, "", output) SuccessOutput(response, "", output)
return
} }
}, },
} }
@ -236,11 +210,9 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -253,14 +225,10 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()), fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response, "", output) SuccessOutput(response, "", output)
return
} }
}, },
} }

View file

@ -16,7 +16,7 @@ var serveCmd = &cobra.Command{
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
app, err := getHeadscaleApp() app, err := newHeadscaleServerWithConfig()
if err != nil { if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing") log.Fatal().Caller().Err(err).Msg("Error initializing")
} }

View file

@ -44,7 +44,7 @@ var createUserCmd = &cobra.Command{
userName := args[0] userName := args[0]
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -63,8 +63,6 @@ var createUserCmd = &cobra.Command{
), ),
output, output,
) )
return
} }
SuccessOutput(response.GetUser(), "User created", output) SuccessOutput(response.GetUser(), "User created", output)
@ -91,7 +89,7 @@ var destroyUserCmd = &cobra.Command{
Name: userName, Name: userName,
} }
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -102,8 +100,6 @@ var destroyUserCmd = &cobra.Command{
fmt.Sprintf("Error: %s", status.Convert(err).Message()), fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output, output,
) )
return
} }
confirm := false confirm := false
@ -134,8 +130,6 @@ var destroyUserCmd = &cobra.Command{
), ),
output, output,
) )
return
} }
SuccessOutput(response, "User destroyed", output) SuccessOutput(response, "User destroyed", output)
} else { } else {
@ -151,7 +145,7 @@ var listUsersCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -164,14 +158,10 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
output, output,
) )
return
} }
if output != "" { if output != "" {
SuccessOutput(response.GetUsers(), "", output) SuccessOutput(response.GetUsers(), "", output)
return
} }
tableData := pterm.TableData{{"ID", "Name", "Created"}} tableData := pterm.TableData{{"ID", "Name", "Created"}}
@ -192,8 +182,6 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return
} }
}, },
} }
@ -213,7 +201,7 @@ var renameUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
@ -232,8 +220,6 @@ var renameUserCmd = &cobra.Command{
), ),
output, output,
) )
return
} }
SuccessOutput(response.GetUser(), "User renamed", output) SuccessOutput(response.GetUser(), "User renamed", output)

View file

@ -23,8 +23,8 @@ const (
SocketWritePermissions = 0o666 SocketWritePermissions = 0o666
) )
func getHeadscaleApp() (*hscontrol.Headscale, error) { func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
cfg, err := types.GetHeadscaleConfig() cfg, err := types.LoadServerConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w", "failed to load configuration while creating headscale instance: %w",
@ -40,8 +40,8 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
return app, nil return app, nil
} }
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := types.GetHeadscaleConfig() cfg, err := types.LoadCLIConfig()
if err != nil { if err != nil {
log.Fatal(). log.Fatal().
Err(err). Err(err).
@ -130,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
return ctx, client, conn, cancel return ctx, client, conn, cancel
} }
func SuccessOutput(result interface{}, override string, outputFormat string) { func output(result interface{}, override string, outputFormat string) string {
var jsonBytes []byte var jsonBytes []byte
var err error var err error
switch outputFormat { switch outputFormat {
@ -151,21 +151,26 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
} }
default: default:
// nolint // nolint
fmt.Println(override) return override
return
} }
// nolint return string(jsonBytes)
fmt.Println(string(jsonBytes))
} }
// SuccessOutput prints the result to stdout and exits with status code 0.
func SuccessOutput(result interface{}, override string, outputFormat string) {
fmt.Println(output(result, override, outputFormat))
os.Exit(0)
}
// ErrorOutput prints an error message to stderr and exits with status code 1.
func ErrorOutput(errResult error, override string, outputFormat string) { func ErrorOutput(errResult error, override string, outputFormat string) {
type errOutput struct { type errOutput struct {
Error string `json:"error"` Error string `json:"error"`
} }
SuccessOutput(errOutput{errResult.Error()}, override, outputFormat) fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errResult.Error()}, override, outputFormat))
os.Exit(1)
} }
func HasMachineOutputFlag() bool { func HasMachineOutputFlag() bool {

View file

@ -4,7 +4,6 @@ import (
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@ -113,60 +112,3 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false) c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false)
} }
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
}
func (*Suite) TestTLSConfigValidation(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
c.Assert(
tmp,
check.Matches,
".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
writeConfig(c, tmpDir, configYaml)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)
}

View file

@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy(
case types.PolicyModeDB: case types.PolicyModeDB:
p, err := api.h.db.GetPolicy() p, err := api.h.db.GetPolicy()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("loading ACL from database: %w", err)
} }
return &v1.GetPolicyResponse{ return &v1.GetPolicyResponse{
@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy(
absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path) absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
f, err := os.Open(absPath) f, err := os.Open(absPath)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
} }
defer f.Close() defer f.Close()
b, err := io.ReadAll(f) b, err := io.ReadAll(f)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("reading policy from file: %w", err)
} }
return &v1.GetPolicyResponse{Policy: string(b)}, nil return &v1.GetPolicyResponse{Policy: string(b)}, nil
} }
return nil, nil return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
} }
func (api headscaleV1APIServer) SetPolicy( func (api headscaleV1APIServer) SetPolicy(

View file

@ -212,6 +212,12 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int NodeMapSessionBufferedChanSize int
} }
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
// It has to be called before LoadServerConfig and LoadCLIConfig.
// The configuration is not validated and the caller should check for errors
// using a validation function.
func LoadConfig(path string, isFile bool) error { func LoadConfig(path string, isFile bool) error {
if isFile { if isFile {
viper.SetConfigFile(path) viper.SetConfigFile(path)
@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err) return fmt.Errorf("fatal error reading config file: %w", err)
} }
return nil
}
func validateServerConfig() error {
depr := deprecator{ depr := deprecator{
warns: make(set.Set[string]), warns: make(set.Set[string]),
fatals: make(set.Set[string]), fatals: make(set.Set[string]),
@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error {
if errorText != "" { if errorText != "" {
// nolint // nolint
return errors.New(strings.TrimSuffix(errorText, "\n")) return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
} }
return nil
} }
func GetTLSConfig() TLSConfig { func tlsConfig() TLSConfig {
return TLSConfig{ return TLSConfig{
LetsEncrypt: LetsEncryptConfig{ LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"), Hostname: viper.GetString("tls_letsencrypt_hostname"),
@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig {
} }
} }
func GetDERPConfig() DERPConfig { func derpConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled") serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id") serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code") serverRegionCode := viper.GetString("derp.server.region_code")
@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig {
} }
} }
func GetLogTailConfig() LogTailConfig { func logtailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled") enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{ return LogTailConfig{
@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig {
} }
} }
func GetPolicyConfig() PolicyConfig { func policyConfig() PolicyConfig {
policyPath := viper.GetString("policy.path") policyPath := viper.GetString("policy.path")
policyMode := viper.GetString("policy.mode") policyMode := viper.GetString("policy.mode")
@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig {
} }
} }
func GetLogConfig() LogConfig { func logConfig() LogConfig {
logLevelStr := viper.GetString("log.level") logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr) logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil { if err != nil {
@ -473,9 +479,9 @@ func GetLogConfig() LogConfig {
logFormatOpt := viper.GetString("log.format") logFormatOpt := viper.GetString("log.format")
var logFormat string var logFormat string
switch logFormatOpt { switch logFormatOpt {
case "json": case JSONLogFormat:
logFormat = JSONLogFormat logFormat = JSONLogFormat
case "text": case TextLogFormat:
logFormat = TextLogFormat logFormat = TextLogFormat
case "": case "":
logFormat = TextLogFormat logFormat = TextLogFormat
@ -491,7 +497,7 @@ func GetLogConfig() LogConfig {
} }
} }
func GetDatabaseConfig() DatabaseConfig { func databaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug") debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type") type_ := viper.GetString("database.type")
@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig {
} }
} }
func DNS() (DNSConfig, error) { func dns() (DNSConfig, error) {
var dns DNSConfig var dns DNSConfig
// TODO: Use this instead of manually getting settings when // TODO: Use this instead of manually getting settings when
@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) {
return dns, nil return dns, nil
} }
// GlobalResolvers returns the global DNS resolvers // globalResolvers returns the global DNS resolvers
// defined in the config file. // defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver. // If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver. // If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored. // If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver { func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global { for _, nsStr := range d.Nameservers.Global {
@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
return resolvers return resolvers
} }
// SplitResolvers returns a map of domain to DNS resolvers. // splitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver. // If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver. // If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored. // If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver { func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver) routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split { for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver var resolvers []*dnstype.Resolver
@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
return routes return routes
} }
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{} cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS { if dns.BaseDomain == "" && dns.MagicDNS {
@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg.Proxied = dns.MagicDNS cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers() cfg.Resolvers = dns.globalResolvers()
routes := dns.SplitResolvers() routes := dns.splitResolvers()
cfg.Routes = routes cfg.Routes = routes
if dns.BaseDomain != "" { if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain} cfg.Domains = []string{dns.BaseDomain}
@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
return &cfg return &cfg
} }
func PrefixV4() (*netip.Prefix, error) { func prefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4") prefixV4Str := viper.GetString("prefixes.v4")
if prefixV4Str == "" { if prefixV4Str == "" {
@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) {
return &prefixV4, nil return &prefixV4, nil
} }
func PrefixV6() (*netip.Prefix, error) { func prefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6") prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" { if prefixV6Str == "" {
@ -723,27 +729,37 @@ func PrefixV6() (*netip.Prefix, error) {
return &prefixV6, nil return &prefixV6, nil
} }
func GetHeadscaleConfig() (*Config, error) { // LoadCLIConfig returns the needed configuration for the CLI client
if IsCLIConfigured() { // of Headscale to connect to a Headscale server.
return &Config{ func LoadCLIConfig() (*Config, error) {
CLI: CLIConfig{ return &Config{
Address: viper.GetString("cli.address"), DisableUpdateCheck: viper.GetBool("disable_check_updates"),
APIKey: viper.GetString("cli.api_key"), UnixSocket: viper.GetString("unix_socket"),
Timeout: viper.GetDuration("cli.timeout"), CLI: CLIConfig{
Insecure: viper.GetBool("cli.insecure"), Address: viper.GetString("cli.address"),
}, APIKey: viper.GetString("cli.api_key"),
}, nil Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
}
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
if err := validateServerConfig(); err != nil {
return nil, err
} }
logConfig := GetLogConfig() logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level) zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := PrefixV4() prefix4, err := prefixV4()
if err != nil { if err != nil {
return nil, err return nil, err
} }
prefix6, err := PrefixV6() prefix6, err := prefixV6()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) {
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom) return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
} }
dnsConfig, err := DNS() dnsConfig, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
derpConfig := GetDERPConfig() derpConfig := derpConfig()
logTailConfig := GetLogTailConfig() logTailConfig := logtailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port") randomizeClientPort := viper.GetBool("randomize_client_port")
oidcClientSecret := viper.GetString("oidc.client_secret") oidcClientSecret := viper.GetString("oidc.client_secret")
@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) {
MetricsAddr: viper.GetString("metrics_listen_addr"), MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"), GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"), DisableUpdateCheck: false,
PrefixV4: prefix4, PrefixV4: prefix4,
PrefixV6: prefix6, PrefixV6: prefix6,
@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout", "ephemeral_node_inactivity_timeout",
), ),
Database: GetDatabaseConfig(), Database: databaseConfig(),
TLS: GetTLSConfig(), TLS: tlsConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig), DNSConfig: dnsToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS, DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
ACMEEmail: viper.GetString("acme_email"), ACMEEmail: viper.GetString("acme_email"),
@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) {
LogTail: logTailConfig, LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort, RandomizeClientPort: randomizeClientPort,
Policy: GetPolicyConfig(), Policy: policyConfig(),
CLI: CLIConfig{ CLI: CLIConfig{
Address: viper.GetString("cli.address"), Address: viper.GetString("cli.address"),
@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil }, nil
} }
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
type deprecator struct { type deprecator struct {
warns set.Set[string] warns set.Set[string]
fatals set.Set[string] fatals set.Set[string]

View file

@ -1,6 +1,8 @@
package types package types
import ( import (
"os"
"path/filepath"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-config", name: "unmarshal-dns-full-config",
configPath: "testdata/dns_full.yaml", configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
dns, err := DNS() dns, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,12 +50,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig", name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full.yaml", configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
dns, err := DNS() dns, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return DNSToTailcfgDNS(dns), nil return dnsToTailcfgDNS(dns), nil
}, },
want: &tailcfg.DNSConfig{ want: &tailcfg.DNSConfig{
Proxied: true, Proxied: true,
@ -79,7 +81,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-no-magic", name: "unmarshal-dns-full-no-magic",
configPath: "testdata/dns_full_no_magic.yaml", configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
dns, err := DNS() dns, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,12 +107,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig", name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full_no_magic.yaml", configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
dns, err := DNS() dns, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return DNSToTailcfgDNS(dns), nil return dnsToTailcfgDNS(dns), nil
}, },
want: &tailcfg.DNSConfig{ want: &tailcfg.DNSConfig{
Proxied: false, Proxied: false,
@ -136,7 +138,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-in-server-url-err", name: "base-domain-in-server-url-err",
configPath: "testdata/base-domain-in-server-url.yaml", configPath: "testdata/base-domain-in-server-url.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
return GetHeadscaleConfig() return LoadServerConfig()
}, },
want: nil, want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
@ -145,7 +147,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-not-in-server-url", name: "base-domain-not-in-server-url",
configPath: "testdata/base-domain-not-in-server-url.yaml", configPath: "testdata/base-domain-not-in-server-url.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig() cfg, err := LoadServerConfig()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -165,7 +167,7 @@ func TestReadConfig(t *testing.T) {
name: "policy-path-is-loaded", name: "policy-path-is-loaded",
configPath: "testdata/policy-path-is-loaded.yaml", configPath: "testdata/policy-path-is-loaded.yaml",
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig() cfg, err := LoadServerConfig()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -245,7 +247,7 @@ func TestReadConfigFromEnv(t *testing.T) {
setup: func(t *testing.T) (any, error) { setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings()) t.Logf("all settings: %#v", viper.AllSettings())
dns, err := DNS() dns, err := dns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -289,3 +291,49 @@ func TestReadConfigFromEnv(t *testing.T) {
}) })
} }
} }
func TestTLSConfigValidation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
t.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
// Check configuration validation errors (1)
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
err = validateServerConfig()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
"strings"
"testing" "testing"
"time" "time"
@ -735,13 +736,7 @@ func TestNodeTagCommand(t *testing.T) {
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
// try to set a wrong tag and retrieve the error _, err = headscale.Execute(
type errOutput struct {
Error string `json:"error"`
}
var errorOutput errOutput
err = executeAndUnmarshal(
headscale,
[]string{ []string{
"headscale", "headscale",
"nodes", "nodes",
@ -750,10 +745,8 @@ func TestNodeTagCommand(t *testing.T) {
"-t", "wrong-tag", "-t", "wrong-tag",
"--output", "json", "--output", "json",
}, },
&errorOutput,
) )
assert.Nil(t, err) assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'")
// Test list all nodes after added seconds // Test list all nodes after added seconds
resultMachines := make([]*v1.Node, len(machineKeys)) resultMachines := make([]*v1.Node, len(machineKeys))
@ -1398,18 +1391,17 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5") assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5")
// Test failure for too long names // Test failure for too long names
result, err := headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
"nodes", "nodes",
"rename", "rename",
"--identifier", "--identifier",
fmt.Sprintf("%d", listAll[4].GetId()), fmt.Sprintf("%d", listAll[4].GetId()),
"testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890", strings.Repeat("t", 64),
}, },
) )
assert.Nil(t, err) assert.ErrorContains(t, err, "not be over 63 chars")
assert.Contains(t, result, "not be over 63 chars")
var listAllAfterRenameAttempt []v1.Node var listAllAfterRenameAttempt []v1.Node
err = executeAndUnmarshal( err = executeAndUnmarshal(
@ -1536,7 +1528,7 @@ func TestNodeMoveCommand(t *testing.T) {
assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
moveToNonExistingNSResult, err := headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
"nodes", "nodes",
@ -1549,11 +1541,9 @@ func TestNodeMoveCommand(t *testing.T) {
"json", "json",
}, },
) )
assert.Nil(t, err) assert.ErrorContains(
assert.Contains(
t, t,
moveToNonExistingNSResult, err,
"user not found", "user not found",
) )
assert.Equal(t, node.GetUser().GetName(), "new-user") assert.Equal(t, node.GetUser().GetName(), "new-user")