diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index af4391a3..593fbd49 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -33,15 +33,19 @@ const ( HeadscaleDateTimeFormat = "2006-01-02 15:04:05" ) -func LoadConfig(path string) error { - viper.SetConfigName("config") - if path == "" { - viper.AddConfigPath("/etc/headscale/") - viper.AddConfigPath("$HOME/.headscale") - viper.AddConfigPath(".") +func LoadConfig(path string, isFile bool) error { + if isFile { + viper.SetConfigFile(path) } else { - // For testing - viper.AddConfigPath(path) + viper.SetConfigName("config") + if path == "" { + viper.AddConfigPath("/etc/headscale/") + viper.AddConfigPath("$HOME/.headscale") + viper.AddConfigPath(".") + } else { + // For testing + viper.AddConfigPath(path) + } } viper.SetEnvPrefix("headscale") diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 600b186e..28b5f2ed 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -43,7 +43,7 @@ func main() { NoColor: !colors, }) - if err := cli.LoadConfig(""); err != nil { + if err := cli.LoadConfig("", false); err != nil { log.Fatal().Caller().Err(err) } diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index faf55f4c..92bba4b0 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -27,6 +27,51 @@ func (s *Suite) SetUpSuite(c *check.C) { func (s *Suite) TearDownSuite(c *check.C) { } +func (*Suite) TestConfigFileLoading(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + path, err := os.Getwd() + if err != nil { + c.Fatal(err) + } + + cfgFile := filepath.Join(tmpDir, "config.yaml") + + // Symlink the example config file + err = os.Symlink( + filepath.Clean(path+"/../../config-example.yaml"), + cfgFile, + ) + if err != nil { + c.Fatal(err) + } + + // Load example config, it should load without validation errors + err = cli.LoadConfig(cfgFile, true) + c.Assert(err, check.IsNil) + + // Test that config file was interpreted correctly + c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") + c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8080") + c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") + c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") + c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") + c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") + c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") + c.Assert( + cli.GetFileMode("unix_socket_permission"), + check.Equals, + fs.FileMode(0o770), + ) + c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) +} + func (*Suite) TestConfigLoading(c *check.C) { tmpDir, err := ioutil.TempDir("", "headscale") if err != nil { @@ -49,7 +94,7 @@ func (*Suite) TestConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = cli.LoadConfig(tmpDir) + err = cli.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) // Test that config file was interpreted correctly @@ -92,7 +137,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = cli.LoadConfig(tmpDir) + err = cli.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) dnsConfig, baseDomain := cli.GetDNSConfig() @@ -125,7 +170,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { writeConfig(c, tmpDir, configYaml) // Check configuration validation errors (1) - err = cli.LoadConfig(tmpDir) + err = cli.LoadConfig(tmpDir, false) c.Assert(err, check.NotNil) // check.Matches can not handle multiline strings tmp := strings.ReplaceAll(err.Error(), "\n", "***") @@ -150,6 +195,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { "---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"", ) writeConfig(c, tmpDir, configYaml) - err = cli.LoadConfig(tmpDir) + err = cli.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) }