diff --git a/Makefile b/Makefile index b053e297..f6bf5ba5 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ build: dev: lint test build test: - go test -coverprofile=coverage.out + @go test -coverprofile=coverage.out ./... coverprofile_func: go tool cover -func=coverage.out diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 634e85f2..f9dcaa17 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -249,11 +249,16 @@ var createPreAuthKeyCmd = &cobra.Command{ }, } -func main() { +func loadConfig(path string) { viper.SetConfigName("config") - viper.AddConfigPath("/etc/headscale/") - viper.AddConfigPath("$HOME/.headscale") - viper.AddConfigPath(".") + if path == "" { + viper.AddConfigPath("/etc/headscale/") + viper.AddConfigPath("$HOME/.headscale") + viper.AddConfigPath(".") + } else { + // For testing + viper.AddConfigPath(path) + } viper.AutomaticEnv() viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") @@ -279,6 +284,10 @@ func main() { if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") { log.Fatalf("Fatal config error: server_url must start with https:// or http://") } +} + +func main() { + loadConfig("") headscaleCmd.AddCommand(versionCmd) headscaleCmd.AddCommand(serveCmd) @@ -302,7 +311,6 @@ func main() { fmt.Println(err) os.Exit(-1) } - } func absPath(path string) string { diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go new file mode 100644 index 00000000..c1fa3c07 --- /dev/null +++ b/cmd/headscale/headscale_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/spf13/viper" + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +func (s *Suite) SetUpSuite(c *check.C) { +} + +func (s *Suite) TearDownSuite(c *check.C) { + +} + +func (*Suite) TestConfigLoading(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + path, err := os.Getwd() + if err != nil { + c.Fatal(err) + } + + // Symlink the example config file + err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json")) + if err != nil { + c.Fatal(err) + } + + // Load config + loadConfig(tmpDir) + + // Test that config file was interpreted correctly + c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") + c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") + c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") + c.Assert(viper.GetString("db_port"), check.Equals, "5432") + c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") + c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") +}