mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-02 03:33:05 +00:00
Merge pull request #28 from juanfont/sqlite-support-plus-integration
Adding SQLite support
This commit is contained in:
commit
1faed2764f
9 changed files with 84 additions and 13 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -17,3 +17,4 @@
|
||||||
/headscale
|
/headscale
|
||||||
config.json
|
config.json
|
||||||
*.key
|
*.key
|
||||||
|
/db.sqlite
|
||||||
|
|
|
@ -44,7 +44,7 @@ Suggestions/PRs welcomed!
|
||||||
make
|
make
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Get yourself a PostgreSQL DB running (yes, [I know](https://tailscale.com/blog/an-unlikely-database-migration/))
|
2. (Optional, you can also use SQLite) Get yourself a PostgreSQL DB running
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker run --name headscale -e POSTGRES_DB=headscale -e \
|
docker run --name headscale -e POSTGRES_DB=headscale -e \
|
||||||
|
@ -55,7 +55,12 @@ Suggestions/PRs welcomed!
|
||||||
```shell
|
```shell
|
||||||
wg genkey > private.key
|
wg genkey > private.key
|
||||||
wg pubkey < private.key > public.key # not needed
|
wg pubkey < private.key > public.key # not needed
|
||||||
cp config.json.example config.json
|
|
||||||
|
# Postgres
|
||||||
|
cp config.json.postgres.example config.json
|
||||||
|
# or
|
||||||
|
# SQLite
|
||||||
|
cp config.json.sqlite.example config.json
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Create a namespace (equivalent to a user in tailscale.com)
|
4. Create a namespace (equivalent to a user in tailscale.com)
|
||||||
|
|
2
api.go
2
api.go
|
@ -279,12 +279,14 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgc
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
h.pollMu.Lock()
|
||||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error generating the keep alive msg: %s", err)
|
log.Printf("Error generating the keep alive msg: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pollData <- *data
|
pollData <- *data
|
||||||
|
h.pollMu.Unlock()
|
||||||
time.Sleep(60 * time.Second)
|
time.Sleep(60 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
21
app.go
21
app.go
|
@ -22,6 +22,8 @@ type Config struct {
|
||||||
PrivateKeyPath string
|
PrivateKeyPath string
|
||||||
DerpMap *tailcfg.DERPMap
|
DerpMap *tailcfg.DERPMap
|
||||||
|
|
||||||
|
DBtype string
|
||||||
|
DBpath string
|
||||||
DBhost string
|
DBhost string
|
||||||
DBport int
|
DBport int
|
||||||
DBname string
|
DBname string
|
||||||
|
@ -60,11 +62,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pubKey := privKey.Public()
|
pubKey := privKey.Public()
|
||||||
|
|
||||||
|
var dbString string
|
||||||
|
switch cfg.DBtype {
|
||||||
|
case "postgres":
|
||||||
|
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
|
||||||
|
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
|
||||||
|
case "sqlite3":
|
||||||
|
dbString = cfg.DBpath
|
||||||
|
default:
|
||||||
|
return nil, errors.New("Unsupported DB")
|
||||||
|
}
|
||||||
|
|
||||||
h := Headscale{
|
h := Headscale{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: "postgres",
|
dbType: cfg.DBtype,
|
||||||
dbString: fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
|
dbString: dbString,
|
||||||
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass),
|
|
||||||
privateKey: privKey,
|
privateKey: privKey,
|
||||||
publicKey: &pubKey,
|
publicKey: &pubKey,
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,10 +22,10 @@ type ErrorOutput struct {
|
||||||
func absPath(path string) string {
|
func absPath(path string) string {
|
||||||
// If a relative path is provided, prefix it with the the directory where
|
// If a relative path is provided, prefix it with the the directory where
|
||||||
// the config file was found.
|
// the config file was found.
|
||||||
if (path != "") && !strings.HasPrefix(path, "/") {
|
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
|
||||||
dir, _ := filepath.Split(viper.ConfigFileUsed())
|
dir, _ := filepath.Split(viper.ConfigFileUsed())
|
||||||
if dir != "" {
|
if dir != "" {
|
||||||
path = dir + "/" + path
|
path = filepath.Join(dir, path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return path
|
return path
|
||||||
|
@ -43,6 +43,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
|
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
|
||||||
DerpMap: derpMap,
|
DerpMap: derpMap,
|
||||||
|
|
||||||
|
DBtype: viper.GetString("db_type"),
|
||||||
|
DBpath: absPath(viper.GetString("db_path")),
|
||||||
DBhost: viper.GetString("db_host"),
|
DBhost: viper.GetString("db_host"),
|
||||||
DBport: viper.GetInt("db_port"),
|
DBport: viper.GetInt("db_port"),
|
||||||
DBname: viper.GetString("db_name"),
|
DBname: viper.GetString("db_name"),
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (s *Suite) TearDownSuite(c *check.C) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestConfigLoading(c *check.C) {
|
func (*Suite) TestPostgresConfigLoading(c *check.C) {
|
||||||
tmpDir, err := ioutil.TempDir("", "headscale")
|
tmpDir, err := ioutil.TempDir("", "headscale")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
|
@ -40,7 +40,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Symlink the example config file
|
// Symlink the example config file
|
||||||
err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json"))
|
err = os.Symlink(filepath.Clean(path+"/../../config.json.postgres.example"), filepath.Join(tmpDir, "config.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -50,14 +50,47 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Test that config file was interpreted correctly
|
// Test that config file was interpreted correctly
|
||||||
c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000")
|
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000")
|
||||||
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000")
|
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000")
|
||||||
c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml")
|
c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml")
|
||||||
|
c.Assert(viper.GetString("db_type"), check.Equals, "postgres")
|
||||||
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
|
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestSqliteConfigLoading(c *check.C) {
|
||||||
|
tmpDir, err := ioutil.TempDir("", "headscale")
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
path, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Symlink the example config file
|
||||||
|
err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load example config, it should load without validation errors
|
||||||
|
err = loadConfig(tmpDir)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Test that config file was interpreted correctly
|
||||||
|
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000")
|
||||||
|
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000")
|
||||||
|
c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml")
|
||||||
|
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
|
||||||
|
c.Assert(viper.GetString("db_path"), check.Equals, "db.sqlite")
|
||||||
|
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||||
|
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||||
|
}
|
||||||
|
|
||||||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
||||||
// Populate a custom config file
|
// Populate a custom config file
|
||||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
@ -89,7 +122,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
||||||
fmt.Println(tmp)
|
fmt.Println(tmp)
|
||||||
|
|
||||||
// Check configuration validation errors (2)
|
// Check configuration validation errors (2)
|
||||||
configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"")
|
configYaml = []byte("---\nserver_url: \"http://127.0.0.1:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"")
|
||||||
writeConfig(c, tmpDir, configYaml)
|
writeConfig(c, tmpDir, configYaml)
|
||||||
err = loadConfig(tmpDir)
|
err = loadConfig(tmpDir)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
{
|
{
|
||||||
"server_url": "http://192.168.1.12:8000",
|
"server_url": "http://127.0.0.1:8000",
|
||||||
"listen_addr": "0.0.0.0:8000",
|
"listen_addr": "0.0.0.0:8000",
|
||||||
"private_key_path": "private.key",
|
"private_key_path": "private.key",
|
||||||
"derp_map_path": "derp.yaml",
|
"derp_map_path": "derp.yaml",
|
||||||
|
"db_type": "postgres",
|
||||||
"db_host": "localhost",
|
"db_host": "localhost",
|
||||||
"db_port": 5432,
|
"db_port": 5432,
|
||||||
"db_name": "headscale",
|
"db_name": "headscale",
|
13
config.json.sqlite.example
Normal file
13
config.json.sqlite.example
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
{
|
||||||
|
"server_url": "http://127.0.0.1:8000",
|
||||||
|
"listen_addr": "0.0.0.0:8000",
|
||||||
|
"private_key_path": "private.key",
|
||||||
|
"derp_map_path": "derp.yaml",
|
||||||
|
"db_type": "sqlite3",
|
||||||
|
"db_path": "db.sqlite",
|
||||||
|
"tls_letsencrypt_hostname": "",
|
||||||
|
"tls_letsencrypt_cache_dir": ".cache",
|
||||||
|
"tls_letsencrypt_challenge_type": "HTTP-01",
|
||||||
|
"tls_cert_path": "",
|
||||||
|
"tls_key_path": ""
|
||||||
|
}
|
1
db.go
1
db.go
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
_ "github.com/jinzhu/gorm/dialects/postgres" // sql driver
|
_ "github.com/jinzhu/gorm/dialects/postgres" // sql driver
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver
|
||||||
)
|
)
|
||||||
|
|
||||||
const dbVersion = "1"
|
const dbVersion = "1"
|
||||||
|
|
Loading…
Reference in a new issue