Merge pull request #28 from juanfont/sqlite-support-plus-integration

Adding SQLite support
This commit is contained in:
Juan Font 2021-05-18 23:49:04 +02:00 committed by GitHub
commit 1faed2764f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 84 additions and 13 deletions

1
.gitignore vendored
View file

@ -17,3 +17,4 @@
/headscale /headscale
config.json config.json
*.key *.key
/db.sqlite

View file

@ -44,7 +44,7 @@ Suggestions/PRs welcomed!
make make
``` ```
2. Get yourself a PostgreSQL DB running (yes, [I know](https://tailscale.com/blog/an-unlikely-database-migration/)) 2. (Optional, you can also use SQLite) Get yourself a PostgreSQL DB running
```shell ```shell
docker run --name headscale -e POSTGRES_DB=headscale -e \ docker run --name headscale -e POSTGRES_DB=headscale -e \
@ -55,7 +55,12 @@ Suggestions/PRs welcomed!
```shell ```shell
wg genkey > private.key wg genkey > private.key
wg pubkey < private.key > public.key # not needed wg pubkey < private.key > public.key # not needed
cp config.json.example config.json
# Postgres
cp config.json.postgres.example config.json
# or
# SQLite
cp config.json.sqlite.example config.json
``` ```
4. Create a namespace (equivalent to a user in tailscale.com) 4. Create a namespace (equivalent to a user in tailscale.com)

2
api.go
View file

@ -279,12 +279,14 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgc
return return
default: default:
h.pollMu.Lock()
data, err := h.getMapKeepAliveResponse(mKey, req, m) data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil { if err != nil {
log.Printf("Error generating the keep alive msg: %s", err) log.Printf("Error generating the keep alive msg: %s", err)
return return
} }
pollData <- *data pollData <- *data
h.pollMu.Unlock()
time.Sleep(60 * time.Second) time.Sleep(60 * time.Second)
} }
} }

21
app.go
View file

@ -22,6 +22,8 @@ type Config struct {
PrivateKeyPath string PrivateKeyPath string
DerpMap *tailcfg.DERPMap DerpMap *tailcfg.DERPMap
DBtype string
DBpath string
DBhost string DBhost string
DBport int DBport int
DBname string DBname string
@ -60,11 +62,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err return nil, err
} }
pubKey := privKey.Public() pubKey := privKey.Public()
var dbString string
switch cfg.DBtype {
case "postgres":
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
case "sqlite3":
dbString = cfg.DBpath
default:
return nil, errors.New("Unsupported DB")
}
h := Headscale{ h := Headscale{
cfg: cfg, cfg: cfg,
dbType: "postgres", dbType: cfg.DBtype,
dbString: fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, dbString: dbString,
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass),
privateKey: privKey, privateKey: privKey,
publicKey: &pubKey, publicKey: &pubKey,
} }

View file

@ -22,10 +22,10 @@ type ErrorOutput struct {
func absPath(path string) string { func absPath(path string) string {
// If a relative path is provided, prefix it with the the directory where // If a relative path is provided, prefix it with the the directory where
// the config file was found. // the config file was found.
if (path != "") && !strings.HasPrefix(path, "/") { if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed()) dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" { if dir != "" {
path = dir + "/" + path path = filepath.Join(dir, path)
} }
} }
return path return path
@ -43,6 +43,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
DerpMap: derpMap, DerpMap: derpMap,
DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"), DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"), DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"), DBname: viper.GetString("db_name"),

View file

@ -27,7 +27,7 @@ func (s *Suite) TearDownSuite(c *check.C) {
} }
func (*Suite) TestConfigLoading(c *check.C) { func (*Suite) TestPostgresConfigLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale") tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
@ -40,7 +40,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
} }
// Symlink the example config file // Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json")) err = os.Symlink(filepath.Clean(path+"/../../config.json.postgres.example"), filepath.Join(tmpDir, "config.json"))
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
@ -50,14 +50,47 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Test that config file was interpreted correctly // Test that config file was interpreted correctly
c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000")
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000")
c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml")
c.Assert(viper.GetString("db_type"), check.Equals, "postgres")
c.Assert(viper.GetString("db_port"), check.Equals, "5432") c.Assert(viper.GetString("db_port"), check.Equals, "5432")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
} }
func (*Suite) TestSqliteConfigLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
defer os.RemoveAll(tmpDir)
path, err := os.Getwd()
if err != nil {
c.Fatal(err)
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
if err != nil {
c.Fatal(err)
}
// Load example config, it should load without validation errors
err = loadConfig(tmpDir)
c.Assert(err, check.IsNil)
// Test that config file was interpreted correctly
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000")
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000")
c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
c.Assert(viper.GetString("db_path"), check.Equals, "db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
}
func writeConfig(c *check.C, tmpDir string, configYaml []byte) { func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file // Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml") configFile := filepath.Join(tmpDir, "config.yaml")
@ -89,7 +122,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
fmt.Println(tmp) fmt.Println(tmp)
// Check configuration validation errors (2) // Check configuration validation errors (2)
configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") configYaml = []byte("---\nserver_url: \"http://127.0.0.1:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"")
writeConfig(c, tmpDir, configYaml) writeConfig(c, tmpDir, configYaml)
err = loadConfig(tmpDir) err = loadConfig(tmpDir)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)

View file

@ -1,8 +1,9 @@
{ {
"server_url": "http://192.168.1.12:8000", "server_url": "http://127.0.0.1:8000",
"listen_addr": "0.0.0.0:8000", "listen_addr": "0.0.0.0:8000",
"private_key_path": "private.key", "private_key_path": "private.key",
"derp_map_path": "derp.yaml", "derp_map_path": "derp.yaml",
"db_type": "postgres",
"db_host": "localhost", "db_host": "localhost",
"db_port": 5432, "db_port": 5432,
"db_name": "headscale", "db_name": "headscale",

View file

@ -0,0 +1,13 @@
{
"server_url": "http://127.0.0.1:8000",
"listen_addr": "0.0.0.0:8000",
"private_key_path": "private.key",
"derp_map_path": "derp.yaml",
"db_type": "sqlite3",
"db_path": "db.sqlite",
"tls_letsencrypt_hostname": "",
"tls_letsencrypt_cache_dir": ".cache",
"tls_letsencrypt_challenge_type": "HTTP-01",
"tls_cert_path": "",
"tls_key_path": ""
}

1
db.go
View file

@ -5,6 +5,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/postgres" // sql driver _ "github.com/jinzhu/gorm/dialects/postgres" // sql driver
_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver
) )
const dbVersion = "1" const dbVersion = "1"