diff --git a/app.go b/app.go index 054fd178..01528fb9 100644 --- a/app.go +++ b/app.go @@ -6,10 +6,8 @@ import ( "errors" "fmt" "io" - "io/fs" "net" "net/http" - "net/url" "os" "os/signal" "sort" @@ -42,7 +40,6 @@ import ( "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" "gorm.io/gorm" - "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" @@ -72,95 +69,9 @@ const ( EnforcedClientAuth = "enforced" ) -// Config contains the initial Headscale configuration. -type Config struct { - ServerURL string - Addr string - MetricsAddr string - GRPCAddr string - GRPCAllowInsecure bool - EphemeralNodeInactivityTimeout time.Duration - IPPrefixes []netaddr.IPPrefix - PrivateKeyPath string - BaseDomain string - - DERP DERPConfig - - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - - TLSLetsEncryptListen string - TLSLetsEncryptHostname string - TLSLetsEncryptCacheDir string - TLSLetsEncryptChallengeType string - - TLSCertPath string - TLSKeyPath string - TLSClientAuthMode tls.ClientAuthType - - ACMEURL string - ACMEEmail string - - DNSConfig *tailcfg.DNSConfig - - UnixSocket string - UnixSocketPermission fs.FileMode - - OIDC OIDCConfig - - LogTail LogTailConfig - - CLI CLIConfig - - ACL ACLConfig -} - -type OIDCConfig struct { - Issuer string - ClientID string - ClientSecret string - Scope []string - ExtraParams map[string]string - AllowedDomains []string - AllowedUsers []string - StripEmaildomain bool -} - -type DERPConfig struct { - ServerEnabled bool - ServerRegionID int - ServerRegionCode string - ServerRegionName string - STUNAddr string - URLs []url.URL - Paths []string - AutoUpdate bool - UpdateFrequency time.Duration -} - -type LogTailConfig struct { - Enabled bool -} - -type CLIConfig struct { - Address string - APIKey string - Timeout time.Duration - Insecure bool -} - -type ACLConfig struct { - PolicyPath string -} - // Headscale represents the base app of the service. type Headscale struct { - cfg Config + cfg *Config db *gorm.DB dbString string dbType string @@ -204,7 +115,7 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) { } } -func NewHeadscale(cfg Config) (*Headscale, error) { +func NewHeadscale(cfg *Config) (*Headscale, error) { privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) if err != nil { return nil, fmt.Errorf("failed to read or create private key: %w", err) @@ -778,7 +689,7 @@ func (h *Headscale) Serve() error { func (h *Headscale) getTLSSettings() (*tls.Config, error) { var err error - if h.cfg.TLSLetsEncryptHostname != "" { + if h.cfg.TLS.LetsEncrypt.Hostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn(). Msg("Listening with TLS but ServerURL does not start with https://") @@ -786,15 +697,15 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname), - Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), + HostPolicy: autocert.HostWhitelist(h.cfg.TLS.LetsEncrypt.Hostname), + Cache: autocert.DirCache(h.cfg.TLS.LetsEncrypt.CacheDir), Client: &acme.Client{ DirectoryURL: h.cfg.ACMEURL, }, Email: h.cfg.ACMEEmail, } - switch h.cfg.TLSLetsEncryptChallengeType { + switch h.cfg.TLS.LetsEncrypt.ChallengeType { case "TLS-ALPN-01": // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // The RFC requires that the validation is done on port 443; in other words, headscale @@ -808,7 +719,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { go func() { log.Fatal(). Caller(). - Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). + Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). Msg("failed to set up a HTTP server") }() @@ -817,7 +728,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { default: return nil, errUnsupportedLetsEncryptChallengeType } - } else if h.cfg.TLSCertPath == "" { + } else if h.cfg.TLS.CertPath == "" { if !strings.HasPrefix(h.cfg.ServerURL, "http://") { log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") } @@ -830,16 +741,16 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { log.Info().Msg(fmt.Sprintf( "Client authentication (mTLS) is \"%s\". See the docs to learn about configuring this setting.", - h.cfg.TLSClientAuthMode)) + h.cfg.TLS.ClientAuthMode)) tlsConfig := &tls.Config{ - ClientAuth: h.cfg.TLSClientAuthMode, + ClientAuth: h.cfg.TLS.ClientAuthMode, NextProtos: []string{"http/1.1"}, Certificates: make([]tls.Certificate, 1), MinVersion: tls.VersionTLS12, } - tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath) + tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLS.CertPath, h.cfg.TLS.KeyPath) return tlsConfig, err } diff --git a/app_test.go b/app_test.go index 96036a1d..170db482 100644 --- a/app_test.go +++ b/app_test.go @@ -46,7 +46,7 @@ func (s *Suite) ResetDB(c *check.C) { } app = Headscale{ - cfg: cfg, + cfg: &cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", } diff --git a/cmd/headscale/cli/server.go b/cmd/headscale/cli/server.go index c19580b9..a1d19600 100644 --- a/cmd/headscale/cli/server.go +++ b/cmd/headscale/cli/server.go @@ -16,12 +16,12 @@ var serveCmd = &cobra.Command{ return nil }, Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() + app, err := getHeadscaleApp() if err != nil { log.Fatal().Caller().Err(err).Msg("Error initializing") } - err = h.Serve() + err = app.Serve() if err != nil { log.Fatal().Caller().Err(err).Msg("Error starting server") } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index af4391a3..f5c679c2 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -4,17 +4,11 @@ import ( "context" "crypto/tls" "encoding/json" - "errors" "fmt" - "io/fs" - "net/url" "os" "reflect" - "strconv" - "strings" "time" - "github.com/coreos/go-oidc/v3/oidc" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" @@ -23,380 +17,18 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "gopkg.in/yaml.v2" - "inet.af/netaddr" - "tailscale.com/tailcfg" - "tailscale.com/types/dnstype" ) const ( - PermissionFallback = 0o700 HeadscaleDateTimeFormat = "2006-01-02 15:04:05" ) -func LoadConfig(path string) error { - viper.SetConfigName("config") - if path == "" { - viper.AddConfigPath("/etc/headscale/") - viper.AddConfigPath("$HOME/.headscale") - viper.AddConfigPath(".") - } else { - // For testing - viper.AddConfigPath(path) - } - - viper.SetEnvPrefix("headscale") - viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - viper.AutomaticEnv() - - viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") - viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") - viper.SetDefault("tls_client_auth_mode", "relaxed") - - viper.SetDefault("log_level", "info") - - viper.SetDefault("dns_config", nil) - - viper.SetDefault("derp.server.enabled", false) - viper.SetDefault("derp.server.stun.enabled", true) - - viper.SetDefault("unix_socket", "/var/run/headscale.sock") - viper.SetDefault("unix_socket_permission", "0o770") - - viper.SetDefault("grpc_listen_addr", ":50443") - viper.SetDefault("grpc_allow_insecure", false) - - viper.SetDefault("cli.timeout", "5s") - viper.SetDefault("cli.insecure", false) - - viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) - viper.SetDefault("oidc.strip_email_domain", true) - - viper.SetDefault("logtail.enabled", false) - - if err := viper.ReadInConfig(); err != nil { - return fmt.Errorf("fatal error reading config file: %w", err) - } - - // Collect any validation errors and return them all at once - var errorText string - if (viper.GetString("tls_letsencrypt_hostname") != "") && - ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) { - errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" - } - - if (viper.GetString("tls_letsencrypt_hostname") != "") && - (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && - (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { - // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) - log.Warn(). - Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") - } - - if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && - (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { - errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" - } - - if !strings.HasPrefix(viper.GetString("server_url"), "http://") && - !strings.HasPrefix(viper.GetString("server_url"), "https://") { - errorText += "Fatal config error: server_url must start with https:// or http://\n" - } - - _, authModeValid := headscale.LookupTLSClientAuthMode( - viper.GetString("tls_client_auth_mode"), - ) - - if !authModeValid { - errorText += fmt.Sprintf( - "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.", - viper.GetString("tls_client_auth_mode"), - headscale.DisabledClientAuth, - headscale.RelaxedClientAuth, - headscale.EnforcedClientAuth) - } - - if errorText != "" { - //nolint - return errors.New(strings.TrimSuffix(errorText, "\n")) - } else { - return nil - } -} - -func GetDERPConfig() headscale.DERPConfig { - serverEnabled := viper.GetBool("derp.server.enabled") - serverRegionID := viper.GetInt("derp.server.region_id") - serverRegionCode := viper.GetString("derp.server.region_code") - serverRegionName := viper.GetString("derp.server.region_name") - stunAddr := viper.GetString("derp.server.stun_listen_addr") - - if serverEnabled && stunAddr == "" { - log.Fatal(). - Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") - } - - urlStrs := viper.GetStringSlice("derp.urls") - - urls := make([]url.URL, len(urlStrs)) - for index, urlStr := range urlStrs { - urlAddr, err := url.Parse(urlStr) - if err != nil { - log.Error(). - Str("url", urlStr). - Err(err). - Msg("Failed to parse url, ignoring...") - } - - urls[index] = *urlAddr - } - - paths := viper.GetStringSlice("derp.paths") - - autoUpdate := viper.GetBool("derp.auto_update_enabled") - updateFrequency := viper.GetDuration("derp.update_frequency") - - return headscale.DERPConfig{ - ServerEnabled: serverEnabled, - ServerRegionID: serverRegionID, - ServerRegionCode: serverRegionCode, - ServerRegionName: serverRegionName, - STUNAddr: stunAddr, - URLs: urls, - Paths: paths, - AutoUpdate: autoUpdate, - UpdateFrequency: updateFrequency, - } -} - -func GetLogTailConfig() headscale.LogTailConfig { - enabled := viper.GetBool("logtail.enabled") - - return headscale.LogTailConfig{ - Enabled: enabled, - } -} - -func GetACLConfig() headscale.ACLConfig { - policyPath := viper.GetString("acl_policy_path") - - return headscale.ACLConfig{ - PolicyPath: policyPath, - } -} - -func GetDNSConfig() (*tailcfg.DNSConfig, string) { - if viper.IsSet("dns_config") { - dnsConfig := &tailcfg.DNSConfig{} - - if viper.IsSet("dns_config.nameservers") { - nameserversStr := viper.GetStringSlice("dns_config.nameservers") - - nameservers := make([]netaddr.IP, len(nameserversStr)) - resolvers := make([]dnstype.Resolver, len(nameserversStr)) - - for index, nameserverStr := range nameserversStr { - nameserver, err := netaddr.ParseIP(nameserverStr) - if err != nil { - log.Error(). - Str("func", "getDNSConfig"). - Err(err). - Msgf("Could not parse nameserver IP: %s", nameserverStr) - } - - nameservers[index] = nameserver - resolvers[index] = dnstype.Resolver{ - Addr: nameserver.String(), - } - } - - dnsConfig.Nameservers = nameservers - dnsConfig.Resolvers = resolvers - } - - if viper.IsSet("dns_config.restricted_nameservers") { - if len(dnsConfig.Nameservers) > 0 { - dnsConfig.Routes = make(map[string][]dnstype.Resolver) - restrictedDNS := viper.GetStringMapStringSlice( - "dns_config.restricted_nameservers", - ) - for domain, restrictedNameservers := range restrictedDNS { - restrictedResolvers := make( - []dnstype.Resolver, - len(restrictedNameservers), - ) - for index, nameserverStr := range restrictedNameservers { - nameserver, err := netaddr.ParseIP(nameserverStr) - if err != nil { - log.Error(). - Str("func", "getDNSConfig"). - Err(err). - Msgf("Could not parse restricted nameserver IP: %s", nameserverStr) - } - restrictedResolvers[index] = dnstype.Resolver{ - Addr: nameserver.String(), - } - } - dnsConfig.Routes[domain] = restrictedResolvers - } - } else { - log.Warn(). - Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.") - } - } - - if viper.IsSet("dns_config.domains") { - dnsConfig.Domains = viper.GetStringSlice("dns_config.domains") - } - - if viper.IsSet("dns_config.magic_dns") { - magicDNS := viper.GetBool("dns_config.magic_dns") - if len(dnsConfig.Nameservers) > 0 { - dnsConfig.Proxied = magicDNS - } else if magicDNS { - log.Warn(). - Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.") - } - } - - var baseDomain string - if viper.IsSet("dns_config.base_domain") { - baseDomain = viper.GetString("dns_config.base_domain") - } else { - baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled - } - - return dnsConfig, baseDomain - } - - return nil, "" -} - -func GetHeadscaleConfig() headscale.Config { - dnsConfig, baseDomain := GetDNSConfig() - derpConfig := GetDERPConfig() - logConfig := GetLogTailConfig() - - configuredPrefixes := viper.GetStringSlice("ip_prefixes") - parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1) - - legacyPrefixField := viper.GetString("ip_prefix") - if len(legacyPrefixField) > 0 { - log. - Warn(). - Msgf( - "%s, %s", - "use of 'ip_prefix' for configuration is deprecated", - "please see 'ip_prefixes' in the shipped example.", - ) - legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField) - if err != nil { - panic(fmt.Errorf("failed to parse ip_prefix: %w", err)) - } - parsedPrefixes = append(parsedPrefixes, legacyPrefix) - } - - for i, prefixInConfig := range configuredPrefixes { - prefix, err := netaddr.ParseIPPrefix(prefixInConfig) - if err != nil { - panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err)) - } - parsedPrefixes = append(parsedPrefixes, prefix) - } - - prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes)) - { - // dedup - normalizedPrefixes := make(map[string]int, len(parsedPrefixes)) - for i, p := range parsedPrefixes { - normalized, _ := p.Range().Prefix() - normalizedPrefixes[normalized.String()] = i - } - - // convert back to list - for _, i := range normalizedPrefixes { - prefixes = append(prefixes, parsedPrefixes[i]) - } - } - - if len(prefixes) < 1 { - prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10")) - log.Warn(). - Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes) - } - - tlsClientAuthMode, _ := headscale.LookupTLSClientAuthMode( - viper.GetString("tls_client_auth_mode"), - ) - - return headscale.Config{ - ServerURL: viper.GetString("server_url"), - Addr: viper.GetString("listen_addr"), - MetricsAddr: viper.GetString("metrics_listen_addr"), - GRPCAddr: viper.GetString("grpc_listen_addr"), - GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), - - IPPrefixes: prefixes, - PrivateKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("private_key_path")), - BaseDomain: baseDomain, - - DERP: derpConfig, - - EphemeralNodeInactivityTimeout: viper.GetDuration( - "ephemeral_node_inactivity_timeout", - ), - - DBtype: viper.GetString("db_type"), - DBpath: headscale.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - - TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), - TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"), - TLSLetsEncryptCacheDir: headscale.AbsolutePathFromConfigPath( - viper.GetString("tls_letsencrypt_cache_dir"), - ), - TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), - - TLSCertPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_cert_path")), - TLSKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_key_path")), - TLSClientAuthMode: tlsClientAuthMode, - - DNSConfig: dnsConfig, - - ACMEEmail: viper.GetString("acme_email"), - ACMEURL: viper.GetString("acme_url"), - - UnixSocket: viper.GetString("unix_socket"), - UnixSocketPermission: GetFileMode("unix_socket_permission"), - - OIDC: headscale.OIDCConfig{ - Issuer: viper.GetString("oidc.issuer"), - ClientID: viper.GetString("oidc.client_id"), - ClientSecret: viper.GetString("oidc.client_secret"), - Scope: viper.GetStringSlice("oidc.scope"), - ExtraParams: viper.GetStringMapString("oidc.extra_params"), - AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), - AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), - StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), - }, - - LogTail: logConfig, - - CLI: headscale.CLIConfig{ - Address: viper.GetString("cli.address"), - APIKey: viper.GetString("cli.api_key"), - Timeout: viper.GetDuration("cli.timeout"), - Insecure: viper.GetBool("cli.insecure"), - }, - - ACL: GetACLConfig(), - } -} - func getHeadscaleApp() (*headscale.Headscale, error) { + cfg, err := headscale.GetHeadscaleConfig() + if err != nil { + return nil, fmt.Errorf("failed to load configuration while creating headscale instance: %w", err) + } + // Minimum inactivity time out is keepalive timeout (60s) plus a few seconds // to avoid races minInactivityTimeout, _ := time.ParseDuration("65s") @@ -412,8 +44,6 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - cfg := GetHeadscaleConfig() - app, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -436,7 +66,13 @@ func getHeadscaleApp() (*headscale.Headscale, error) { } func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { - cfg := GetHeadscaleConfig() + cfg, err := headscale.GetHeadscaleConfig() + if err != nil { + log.Fatal(). + Err(err). + Caller(). + Msgf("Failed to load configuration") + } log.Debug(). Dur("timeout", cfg.CLI.Timeout). @@ -570,17 +206,6 @@ func (tokenAuth) RequireTransportSecurity() bool { return true } -func GetFileMode(key string) fs.FileMode { - modeStr := viper.GetString(key) - - mode, err := strconv.ParseUint(modeStr, headscale.Base8, headscale.BitSize64) - if err != nil { - return PermissionFallback - } - - return fs.FileMode(mode) -} - func contains[T string](ts []T, t T) bool { for _, v := range ts { if reflect.DeepEqual(v, t) { diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 600b186e..f5e28661 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -7,10 +7,10 @@ import ( "time" "github.com/efekarakus/termcolor" + "github.com/juanfont/headscale" "github.com/juanfont/headscale/cmd/headscale/cli" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/spf13/viper" "github.com/tcnksm/go-latest" ) @@ -43,19 +43,14 @@ func main() { NoColor: !colors, }) - if err := cli.LoadConfig(""); err != nil { + cfg, err := headscale.GetHeadscaleConfig() + if err != nil { log.Fatal().Caller().Err(err) } machineOutput := cli.HasMachineOutputFlag() - logLevel := viper.GetString("log_level") - level, err := zerolog.ParseLevel(logLevel) - if err != nil { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } else { - zerolog.SetGlobalLevel(level) - } + zerolog.SetGlobalLevel(cfg.LogLevel) // If the user has requested a "machine" readable format, // then disable login so the output remains valid. @@ -63,7 +58,7 @@ func main() { zerolog.SetGlobalLevel(zerolog.Disabled) } - if !viper.GetBool("disable_check_updates") && !machineOutput { + if !cfg.DisableUpdateCheck && !machineOutput { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" { githubTag := &latest.GithubTag{ diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index faf55f4c..9ca4a2c3 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/juanfont/headscale/cmd/headscale/cli" + "github.com/juanfont/headscale" "github.com/spf13/viper" "gopkg.in/check.v1" ) @@ -49,7 +49,7 @@ func (*Suite) TestConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = cli.LoadConfig(tmpDir) + err = headscale.LoadConfig(tmpDir) c.Assert(err, check.IsNil) // Test that config file was interpreted correctly @@ -63,7 +63,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") c.Assert( - cli.GetFileMode("unix_socket_permission"), + headscale.GetFileMode("unix_socket_permission"), check.Equals, fs.FileMode(0o770), ) @@ -92,10 +92,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { } // Load example config, it should load without validation errors - err = cli.LoadConfig(tmpDir) + err = headscale.LoadConfig(tmpDir) c.Assert(err, check.IsNil) - dnsConfig, baseDomain := cli.GetDNSConfig() + dnsConfig, baseDomain := headscale.GetDNSConfig() c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") @@ -125,7 +125,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { writeConfig(c, tmpDir, configYaml) // Check configuration validation errors (1) - err = cli.LoadConfig(tmpDir) + err = headscale.LoadConfig(tmpDir) c.Assert(err, check.NotNil) // check.Matches can not handle multiline strings tmp := strings.ReplaceAll(err.Error(), "\n", "***") @@ -150,6 +150,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { "---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"", ) writeConfig(c, tmpDir, configYaml) - err = cli.LoadConfig(tmpDir) + err = headscale.LoadConfig(tmpDir) c.Assert(err, check.IsNil) } diff --git a/config.go b/config.go new file mode 100644 index 00000000..909a48c4 --- /dev/null +++ b/config.go @@ -0,0 +1,504 @@ +package headscale + +import ( + "crypto/tls" + "errors" + "fmt" + "io/fs" + "net/url" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/viper" + "inet.af/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" +) + +// Config contains the initial Headscale configuration. +type Config struct { + ServerURL string + Addr string + MetricsAddr string + GRPCAddr string + GRPCAllowInsecure bool + EphemeralNodeInactivityTimeout time.Duration + IPPrefixes []netaddr.IPPrefix + PrivateKeyPath string + BaseDomain string + LogLevel zerolog.Level + DisableUpdateCheck bool + + DERP DERPConfig + + DBtype string + DBpath string + DBhost string + DBport int + DBname string + DBuser string + DBpass string + + TLS TLSConfig + + ACMEURL string + ACMEEmail string + + DNSConfig *tailcfg.DNSConfig + + UnixSocket string + UnixSocketPermission fs.FileMode + + OIDC OIDCConfig + + LogTail LogTailConfig + + CLI CLIConfig + + ACL ACLConfig +} + +type TLSConfig struct { + CertPath string + KeyPath string + ClientAuthMode tls.ClientAuthType + + LetsEncrypt LetsEncryptConfig +} + +type LetsEncryptConfig struct { + Listen string + Hostname string + CacheDir string + ChallengeType string +} + +type OIDCConfig struct { + Issuer string + ClientID string + ClientSecret string + Scope []string + ExtraParams map[string]string + AllowedDomains []string + AllowedUsers []string + StripEmaildomain bool +} + +type DERPConfig struct { + ServerEnabled bool + ServerRegionID int + ServerRegionCode string + ServerRegionName string + STUNAddr string + URLs []url.URL + Paths []string + AutoUpdate bool + UpdateFrequency time.Duration +} + +type LogTailConfig struct { + Enabled bool +} + +type CLIConfig struct { + Address string + APIKey string + Timeout time.Duration + Insecure bool +} + +type ACLConfig struct { + PolicyPath string +} + +func LoadConfig(path string) error { + viper.SetConfigName("config") + if path == "" { + viper.AddConfigPath("/etc/headscale/") + viper.AddConfigPath("$HOME/.headscale") + viper.AddConfigPath(".") + } else { + // For testing + viper.AddConfigPath(path) + } + + viper.SetEnvPrefix("headscale") + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + viper.AutomaticEnv() + + viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") + viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") + viper.SetDefault("tls_client_auth_mode", "relaxed") + + viper.SetDefault("log_level", "info") + + viper.SetDefault("dns_config", nil) + + viper.SetDefault("derp.server.enabled", false) + viper.SetDefault("derp.server.stun.enabled", true) + + viper.SetDefault("unix_socket", "/var/run/headscale.sock") + viper.SetDefault("unix_socket_permission", "0o770") + + viper.SetDefault("grpc_listen_addr", ":50443") + viper.SetDefault("grpc_allow_insecure", false) + + viper.SetDefault("cli.timeout", "5s") + viper.SetDefault("cli.insecure", false) + + viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) + viper.SetDefault("oidc.strip_email_domain", true) + + viper.SetDefault("logtail.enabled", false) + + if err := viper.ReadInConfig(); err != nil { + return fmt.Errorf("fatal error reading config file: %w", err) + } + + // Collect any validation errors and return them all at once + var errorText string + if (viper.GetString("tls_letsencrypt_hostname") != "") && + ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) { + errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" + } + + if (viper.GetString("tls_letsencrypt_hostname") != "") && + (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && + (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { + // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) + log.Warn(). + Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") + } + + if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && + (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { + errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" + } + + if !strings.HasPrefix(viper.GetString("server_url"), "http://") && + !strings.HasPrefix(viper.GetString("server_url"), "https://") { + errorText += "Fatal config error: server_url must start with https:// or http://\n" + } + + _, authModeValid := LookupTLSClientAuthMode( + viper.GetString("tls_client_auth_mode"), + ) + + if !authModeValid { + errorText += fmt.Sprintf( + "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.", + viper.GetString("tls_client_auth_mode"), + DisabledClientAuth, + RelaxedClientAuth, + EnforcedClientAuth) + } + + if errorText != "" { + //nolint + return errors.New(strings.TrimSuffix(errorText, "\n")) + } else { + return nil + } +} + +func GetTLSConfig() TLSConfig { + tlsClientAuthMode, _ := LookupTLSClientAuthMode( + viper.GetString("tls_client_auth_mode"), + ) + + return TLSConfig{ + LetsEncrypt: LetsEncryptConfig{ + Hostname: viper.GetString("tls_letsencrypt_hostname"), + Listen: viper.GetString("tls_letsencrypt_listen"), + CacheDir: AbsolutePathFromConfigPath( + viper.GetString("tls_letsencrypt_cache_dir"), + ), + ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), + }, + CertPath: AbsolutePathFromConfigPath( + viper.GetString("tls_cert_path"), + ), + KeyPath: AbsolutePathFromConfigPath( + viper.GetString("tls_key_path"), + ), + ClientAuthMode: tlsClientAuthMode, + } +} + +func GetDERPConfig() DERPConfig { + serverEnabled := viper.GetBool("derp.server.enabled") + serverRegionID := viper.GetInt("derp.server.region_id") + serverRegionCode := viper.GetString("derp.server.region_code") + serverRegionName := viper.GetString("derp.server.region_name") + stunAddr := viper.GetString("derp.server.stun_listen_addr") + + if serverEnabled && stunAddr == "" { + log.Fatal(). + Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") + } + + urlStrs := viper.GetStringSlice("derp.urls") + + urls := make([]url.URL, len(urlStrs)) + for index, urlStr := range urlStrs { + urlAddr, err := url.Parse(urlStr) + if err != nil { + log.Error(). + Str("url", urlStr). + Err(err). + Msg("Failed to parse url, ignoring...") + } + + urls[index] = *urlAddr + } + + paths := viper.GetStringSlice("derp.paths") + + autoUpdate := viper.GetBool("derp.auto_update_enabled") + updateFrequency := viper.GetDuration("derp.update_frequency") + + return DERPConfig{ + ServerEnabled: serverEnabled, + ServerRegionID: serverRegionID, + ServerRegionCode: serverRegionCode, + ServerRegionName: serverRegionName, + STUNAddr: stunAddr, + URLs: urls, + Paths: paths, + AutoUpdate: autoUpdate, + UpdateFrequency: updateFrequency, + } +} + +func GetLogTailConfig() LogTailConfig { + enabled := viper.GetBool("logtail.enabled") + + return LogTailConfig{ + Enabled: enabled, + } +} + +func GetACLConfig() ACLConfig { + policyPath := viper.GetString("acl_policy_path") + + return ACLConfig{ + PolicyPath: policyPath, + } +} + +func GetDNSConfig() (*tailcfg.DNSConfig, string) { + if viper.IsSet("dns_config") { + dnsConfig := &tailcfg.DNSConfig{} + + if viper.IsSet("dns_config.nameservers") { + nameserversStr := viper.GetStringSlice("dns_config.nameservers") + + nameservers := make([]netaddr.IP, len(nameserversStr)) + resolvers := make([]dnstype.Resolver, len(nameserversStr)) + + for index, nameserverStr := range nameserversStr { + nameserver, err := netaddr.ParseIP(nameserverStr) + if err != nil { + log.Error(). + Str("func", "getDNSConfig"). + Err(err). + Msgf("Could not parse nameserver IP: %s", nameserverStr) + } + + nameservers[index] = nameserver + resolvers[index] = dnstype.Resolver{ + Addr: nameserver.String(), + } + } + + dnsConfig.Nameservers = nameservers + dnsConfig.Resolvers = resolvers + } + + if viper.IsSet("dns_config.restricted_nameservers") { + if len(dnsConfig.Nameservers) > 0 { + dnsConfig.Routes = make(map[string][]dnstype.Resolver) + restrictedDNS := viper.GetStringMapStringSlice( + "dns_config.restricted_nameservers", + ) + for domain, restrictedNameservers := range restrictedDNS { + restrictedResolvers := make( + []dnstype.Resolver, + len(restrictedNameservers), + ) + for index, nameserverStr := range restrictedNameservers { + nameserver, err := netaddr.ParseIP(nameserverStr) + if err != nil { + log.Error(). + Str("func", "getDNSConfig"). + Err(err). + Msgf("Could not parse restricted nameserver IP: %s", nameserverStr) + } + restrictedResolvers[index] = dnstype.Resolver{ + Addr: nameserver.String(), + } + } + dnsConfig.Routes[domain] = restrictedResolvers + } + } else { + log.Warn(). + Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.") + } + } + + if viper.IsSet("dns_config.domains") { + dnsConfig.Domains = viper.GetStringSlice("dns_config.domains") + } + + if viper.IsSet("dns_config.magic_dns") { + magicDNS := viper.GetBool("dns_config.magic_dns") + if len(dnsConfig.Nameservers) > 0 { + dnsConfig.Proxied = magicDNS + } else if magicDNS { + log.Warn(). + Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.") + } + } + + var baseDomain string + if viper.IsSet("dns_config.base_domain") { + baseDomain = viper.GetString("dns_config.base_domain") + } else { + baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled + } + + return dnsConfig, baseDomain + } + + return nil, "" +} + +func GetHeadscaleConfig() (*Config, error) { + err := LoadConfig("") + if err != nil { + return nil, err + } + + dnsConfig, baseDomain := GetDNSConfig() + derpConfig := GetDERPConfig() + logConfig := GetLogTailConfig() + + configuredPrefixes := viper.GetStringSlice("ip_prefixes") + parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1) + + logLevelStr := viper.GetString("log_level") + logLevel, err := zerolog.ParseLevel(logLevelStr) + if err != nil { + logLevel = zerolog.DebugLevel + } + + legacyPrefixField := viper.GetString("ip_prefix") + if len(legacyPrefixField) > 0 { + log. + Warn(). + Msgf( + "%s, %s", + "use of 'ip_prefix' for configuration is deprecated", + "please see 'ip_prefixes' in the shipped example.", + ) + legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField) + if err != nil { + panic(fmt.Errorf("failed to parse ip_prefix: %w", err)) + } + parsedPrefixes = append(parsedPrefixes, legacyPrefix) + } + + for i, prefixInConfig := range configuredPrefixes { + prefix, err := netaddr.ParseIPPrefix(prefixInConfig) + if err != nil { + panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err)) + } + parsedPrefixes = append(parsedPrefixes, prefix) + } + + prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes)) + { + // dedup + normalizedPrefixes := make(map[string]int, len(parsedPrefixes)) + for i, p := range parsedPrefixes { + normalized, _ := p.Range().Prefix() + normalizedPrefixes[normalized.String()] = i + } + + // convert back to list + for _, i := range normalizedPrefixes { + prefixes = append(prefixes, parsedPrefixes[i]) + } + } + + if len(prefixes) < 1 { + prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10")) + log.Warn(). + Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes) + } + + return &Config{ + ServerURL: viper.GetString("server_url"), + Addr: viper.GetString("listen_addr"), + MetricsAddr: viper.GetString("metrics_listen_addr"), + GRPCAddr: viper.GetString("grpc_listen_addr"), + GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), + DisableUpdateCheck: viper.GetBool("disable_check_updates"), + LogLevel: logLevel, + + IPPrefixes: prefixes, + PrivateKeyPath: AbsolutePathFromConfigPath( + viper.GetString("private_key_path"), + ), + BaseDomain: baseDomain, + + DERP: derpConfig, + + EphemeralNodeInactivityTimeout: viper.GetDuration( + "ephemeral_node_inactivity_timeout", + ), + + DBtype: viper.GetString("db_type"), + DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")), + DBhost: viper.GetString("db_host"), + DBport: viper.GetInt("db_port"), + DBname: viper.GetString("db_name"), + DBuser: viper.GetString("db_user"), + DBpass: viper.GetString("db_pass"), + + TLS: GetTLSConfig(), + + DNSConfig: dnsConfig, + + ACMEEmail: viper.GetString("acme_email"), + ACMEURL: viper.GetString("acme_url"), + + UnixSocket: viper.GetString("unix_socket"), + UnixSocketPermission: GetFileMode("unix_socket_permission"), + + OIDC: OIDCConfig{ + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + Scope: viper.GetStringSlice("oidc.scope"), + ExtraParams: viper.GetStringMapString("oidc.extra_params"), + AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), + AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), + }, + + LogTail: logConfig, + + CLI: CLIConfig{ + Address: viper.GetString("cli.address"), + APIKey: viper.GetString("cli.api_key"), + Timeout: viper.GetDuration("cli.timeout"), + Insecure: viper.GetBool("cli.insecure"), + }, + + ACL: GetACLConfig(), + }, nil +} diff --git a/machine_test.go b/machine_test.go index 0fb3ed78..bde96057 100644 --- a/machine_test.go +++ b/machine_test.go @@ -821,7 +821,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "simple machine name generation", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, @@ -836,7 +836,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "machine name with 53 chars", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, @@ -851,7 +851,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "machine name with 60 chars", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, @@ -866,7 +866,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "machine name with 63 chars", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, @@ -881,7 +881,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "machine name with 64 chars", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, @@ -896,7 +896,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) { { name: "machine name with 73 chars", h: &Headscale{ - cfg: Config{ + cfg: &Config{ OIDC: OIDCConfig{ StripEmaildomain: true, }, diff --git a/utils.go b/utils.go index 6dddf4c5..8d9dec5b 100644 --- a/utils.go +++ b/utils.go @@ -11,10 +11,12 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io/fs" "net" "os" "path/filepath" "reflect" + "strconv" "strings" "github.com/rs/zerolog/log" @@ -55,6 +57,8 @@ const ( // privateKey prefix. privateHexPrefix = "privkey:" + + PermissionFallback = 0o700 ) func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { @@ -350,3 +354,14 @@ func AbsolutePathFromConfigPath(path string) string { return path } + +func GetFileMode(key string) fs.FileMode { + modeStr := viper.GetString(key) + + mode, err := strconv.ParseUint(modeStr, Base8, BitSize64) + if err != nil { + return PermissionFallback + } + + return fs.FileMode(mode) +}