diff --git a/Makefile b/Makefile index b053e297..f6bf5ba5 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ build: dev: lint test build test: - go test -coverprofile=coverage.out + @go test -coverprofile=coverage.out ./... coverprofile_func: go tool cover -func=coverage.out diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 634e85f2..776ffb92 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io" "log" @@ -249,11 +250,16 @@ var createPreAuthKeyCmd = &cobra.Command{ }, } -func main() { +func loadConfig(path string) error { viper.SetConfigName("config") - viper.AddConfigPath("/etc/headscale/") - viper.AddConfigPath("$HOME/.headscale") - viper.AddConfigPath(".") + if path == "" { + viper.AddConfigPath("/etc/headscale/") + viper.AddConfigPath("$HOME/.headscale") + viper.AddConfigPath(".") + } else { + // For testing + viper.AddConfigPath(path) + } viper.AutomaticEnv() viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") @@ -261,23 +267,37 @@ func main() { err := viper.ReadInConfig() if err != nil { - log.Fatalf("Fatal error config file: %s \n", err) + return errors.New(fmt.Sprintf("Fatal error reading config file: %s \n", err)) } + // Collect any validation errors and return them all at once + var errorText string if (viper.GetString("tls_letsencrypt_hostname") != "") && ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) { - log.Fatalf("Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") + errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" } if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { - log.Fatalf("Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443") + errorText += "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443\n" } if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { - log.Fatalf("Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01") + errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" } if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") { - log.Fatalf("Fatal config error: server_url must start with https:// or http://") + errorText += "Fatal config error: server_url must start with https:// or http://\n" + } + if errorText != "" { + return errors.New(strings.TrimSuffix(errorText, "\n")) + } else { + return nil + } +} + +func main() { + err := loadConfig("") + if err != nil { + log.Fatalf(err.Error()) } headscaleCmd.AddCommand(versionCmd) @@ -302,7 +322,6 @@ func main() { fmt.Println(err) os.Exit(-1) } - } func absPath(path string) string { diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go new file mode 100644 index 00000000..a3894f62 --- /dev/null +++ b/cmd/headscale/headscale_test.go @@ -0,0 +1,98 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/viper" + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +func (s *Suite) SetUpSuite(c *check.C) { +} + +func (s *Suite) TearDownSuite(c *check.C) { + +} + +func (*Suite) TestConfigLoading(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + path, err := os.Getwd() + if err != nil { + c.Fatal(err) + } + + // Symlink the example config file + err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json")) + if err != nil { + c.Fatal(err) + } + + // Load example config, it should load without validation errors + err = loadConfig(tmpDir) + c.Assert(err, check.IsNil) + + // Test that config file was interpreted correctly + c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") + c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") + c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") + c.Assert(viper.GetString("db_port"), check.Equals, "5432") + c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") + c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") +} + +func writeConfig(c *check.C, tmpDir string, configYaml []byte) { + // Populate a custom config file + configFile := filepath.Join(tmpDir, "config.yaml") + err := ioutil.WriteFile(configFile, configYaml, 0644) + if err != nil { + c.Fatalf("Couldn't write file %s", configFile) + } +} + +func (*Suite) TestTLSConfigValidation(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + //defer os.RemoveAll(tmpDir) + fmt.Println(tmpDir) + + configYaml := []byte("---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"") + writeConfig(c, tmpDir, configYaml) + + // Check configuration validation errors (1) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + // check.Matches can not handle multiline strings + tmp := strings.ReplaceAll(err.Error(), "\n", "***") + c.Assert(tmp, check.Matches, ".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*") + fmt.Println(tmp) + + // Check configuration validation errors (2) + configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") + fmt.Printf(string(configYaml)) + writeConfig(c, tmpDir, configYaml) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + c.Assert(err, check.ErrorMatches, "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443.*") +}