From 2289a2acbf6ffb838549ae8bd63b2026af52a2aa Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 6 Jun 2023 10:23:39 +0200 Subject: [PATCH] move Config definitions into types Signed-off-by: Kristoffer Dalby --- cmd/headscale/cli/root.go | 10 +++++----- cmd/headscale/cli/utils.go | 5 +++-- cmd/headscale/headscale_test.go | 14 +++++++------- hscontrol/app.go | 8 ++++---- hscontrol/derp.go | 3 ++- hscontrol/suite_test.go | 5 +++-- hscontrol/{ => types}/config.go | 14 +++++++------- integration/auth_oidc_test.go | 6 +++--- 8 files changed, 34 insertions(+), 31 deletions(-) rename hscontrol/{ => types}/config.go (98%) diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index ab76fff5..f70945e6 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -5,7 +5,7 @@ import ( "os" "runtime" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -38,18 +38,18 @@ func initConfig() { cfgFile = os.Getenv("HEADSCALE_CONFIG") } if cfgFile != "" { - err := hscontrol.LoadConfig(cfgFile, true) + err := types.LoadConfig(cfgFile, true) if err != nil { log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile) } } else { - err := hscontrol.LoadConfig("", false) + err := types.LoadConfig("", false) if err != nil { log.Fatal().Caller().Err(err).Msgf("Error loading config") } } - cfg, err := hscontrol.GetHeadscaleConfig() + cfg, err := types.GetHeadscaleConfig() if err != nil { log.Fatal().Caller().Err(err) } @@ -64,7 +64,7 @@ func initConfig() { zerolog.SetGlobalLevel(zerolog.Disabled) } - if cfg.Log.Format == hscontrol.JSONLogFormat { + if cfg.Log.Format == types.JSONLogFormat { log.Logger = log.Output(os.Stdout) } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 5ce7816f..baaf2094 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -11,6 +11,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc" @@ -25,7 +26,7 @@ const ( ) func getHeadscaleApp() (*hscontrol.Headscale, error) { - cfg, err := hscontrol.GetHeadscaleConfig() + cfg, err := types.GetHeadscaleConfig() if err != nil { return nil, fmt.Errorf( "failed to load configuration while creating headscale instance: %w", @@ -57,7 +58,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { } func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { - cfg, err := hscontrol.GetHeadscaleConfig() + cfg, err := types.GetHeadscaleConfig() if err != nil { log.Fatal(). Err(err). diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 89fd7754..897e2537 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/spf13/viper" "gopkg.in/check.v1" @@ -51,7 +51,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) { } // Load example config, it should load without validation errors - err = hscontrol.LoadConfig(cfgFile, true) + err = types.LoadConfig(cfgFile, true) c.Assert(err, check.IsNil) // Test that config file was interpreted correctly @@ -94,7 +94,7 @@ func (*Suite) TestConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = hscontrol.LoadConfig(tmpDir, false) + err = types.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) // Test that config file was interpreted correctly @@ -138,10 +138,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = hscontrol.LoadConfig(tmpDir, false) + err = types.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) - dnsConfig, baseDomain := hscontrol.GetDNSConfig() + dnsConfig, baseDomain := types.GetDNSConfig() c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") @@ -173,7 +173,7 @@ noise: writeConfig(c, tmpDir, configYaml) // Check configuration validation errors (1) - err = hscontrol.LoadConfig(tmpDir, false) + err = types.LoadConfig(tmpDir, false) c.Assert(err, check.NotNil) // check.Matches can not handle multiline strings tmp := strings.ReplaceAll(err.Error(), "\n", "***") @@ -202,6 +202,6 @@ tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: TLS-ALPN-01 `) writeConfig(c, tmpDir, configYaml) - err = hscontrol.LoadConfig(tmpDir, false) + err = types.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) } diff --git a/hscontrol/app.go b/hscontrol/app.go index ac35b6b9..ecce42b1 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -75,7 +75,7 @@ const ( // Headscale represents the base app of the service. type Headscale struct { - cfg *Config + cfg *types.Config db *db.HSDatabase dbString string dbType string @@ -102,7 +102,7 @@ type Headscale struct { cancelStateUpdateChan chan struct{} } -func NewHeadscale(cfg *Config) (*Headscale, error) { +func NewHeadscale(cfg *types.Config) (*Headscale, error) { privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) if err != nil { return nil, fmt.Errorf("failed to read or create private key: %w", err) @@ -778,13 +778,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } switch h.cfg.TLS.LetsEncrypt.ChallengeType { - case tlsALPN01ChallengeType: + case types.TlsALPN01ChallengeType: // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // The RFC requires that the validation is done on port 443; in other words, headscale // must be reachable on port 443. return certManager.TLSConfig(), nil - case http01ChallengeType: + case types.Http01ChallengeType: // Configuration via autocert with HTTP-01. This requires listening on // port 80 for the certificate validation in addition to the headscale // service, which can be configured to run on any other port. diff --git a/hscontrol/derp.go b/hscontrol/derp.go index fbc366a9..1e712a51 100644 --- a/hscontrol/derp.go +++ b/hscontrol/derp.go @@ -9,6 +9,7 @@ import ( "os" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" "tailscale.com/tailcfg" @@ -80,7 +81,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { return &result } -func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap { +func GetDERPMap(cfg types.DERPConfig) *tailcfg.DERPMap { derpMaps := make([]*tailcfg.DERPMap, 0) for _, path := range cfg.Paths { diff --git a/hscontrol/suite_test.go b/hscontrol/suite_test.go index 69a651a8..efee33e0 100644 --- a/hscontrol/suite_test.go +++ b/hscontrol/suite_test.go @@ -5,6 +5,7 @@ import ( "os" "testing" + "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" ) @@ -38,7 +39,7 @@ func (s *Suite) ResetDB(c *check.C) { if err != nil { c.Fatal(err) } - cfg := Config{ + cfg := types.Config{ PrivateKeyPath: tmpDir + "/private.key", NoisePrivateKeyPath: tmpDir + "/noise_private.key", DBtype: "sqlite3", @@ -46,7 +47,7 @@ func (s *Suite) ResetDB(c *check.C) { IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, - OIDC: OIDCConfig{ + OIDC: types.OIDCConfig{ StripEmaildomain: false, }, } diff --git a/hscontrol/config.go b/hscontrol/types/config.go similarity index 98% rename from hscontrol/config.go rename to hscontrol/types/config.go index 63deace0..1172b3e7 100644 --- a/hscontrol/config.go +++ b/hscontrol/types/config.go @@ -1,4 +1,4 @@ -package hscontrol +package types import ( "errors" @@ -23,8 +23,8 @@ import ( ) const ( - tlsALPN01ChallengeType = "TLS-ALPN-01" - http01ChallengeType = "HTTP-01" + TlsALPN01ChallengeType = "TLS-ALPN-01" + Http01ChallengeType = "HTTP-01" JSONLogFormat = "json" TextLogFormat = "text" @@ -165,7 +165,7 @@ func LoadConfig(path string, isFile bool) error { viper.AutomaticEnv() viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") - viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType) + viper.SetDefault("tls_letsencrypt_challenge_type", Http01ChallengeType) viper.SetDefault("log.level", "info") viper.SetDefault("log.format", TextLogFormat) @@ -222,15 +222,15 @@ func LoadConfig(path string, isFile bool) error { } if (viper.GetString("tls_letsencrypt_hostname") != "") && - (viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) && + (viper.GetString("tls_letsencrypt_challenge_type") == TlsALPN01ChallengeType) && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) log.Warn(). Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") } - if (viper.GetString("tls_letsencrypt_challenge_type") != http01ChallengeType) && - (viper.GetString("tls_letsencrypt_challenge_type") != tlsALPN01ChallengeType) { + if (viper.GetString("tls_letsencrypt_challenge_type") != Http01ChallengeType) && + (viper.GetString("tls_letsencrypt_challenge_type") != TlsALPN01ChallengeType) { errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 452f8520..9b05af93 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" @@ -214,7 +214,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDCConfig, error) { +func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { port, err := dockertestutil.RandomFreeHostPort() if err != nil { log.Fatalf("could not find an open port: %s", err) @@ -288,7 +288,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDC log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) - return &hscontrol.OIDCConfig{ + return &types.OIDCConfig{ Issuer: fmt.Sprintf( "http://%s/oidc", net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)),