Update Headscale to depend on gorm v2

This commit is contained in:
Juan Font Alonso 2021-06-24 15:44:19 +02:00
parent df0d214faf
commit 69ba750b38
9 changed files with 47 additions and 40 deletions

9
api.go
View file

@ -3,6 +3,7 @@ package headscale
import ( import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -10,9 +11,9 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"gorm.io/datatypes" "gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
@ -80,10 +81,9 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.String(http.StatusInternalServerError, ":(") c.String(http.StatusInternalServerError, ":(")
return return
} }
defer db.Close()
var m Machine var m Machine
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Println("New Machine!") log.Println("New Machine!")
m = Machine{ m = Machine{
Expiry: &req.Expiry, Expiry: &req.Expiry,
@ -209,9 +209,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return return
} }
defer db.Close()
var m Machine var m Machine
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString()) log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString())
return return
} }

1
app.go
View file

@ -112,7 +112,6 @@ func (h *Headscale) expireEphemeralNodesWorker() {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return return
} }
defer db.Close()
namespaces, err := h.ListNamespaces() namespaces, err := h.ListNamespaces()
if err != nil { if err != nil {

4
cli.go
View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"log" "log"
"gorm.io/gorm"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
) )
@ -22,9 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
m := Machine{} m := Machine{}
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found") return nil, errors.New("Machine not found")
} }

45
db.go
View file

@ -3,9 +3,9 @@ package headscale
import ( import (
"errors" "errors"
"github.com/jinzhu/gorm" "gorm.io/driver/postgres"
_ "github.com/jinzhu/gorm/dialects/postgres" // sql driver "gorm.io/driver/sqlite"
_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver "gorm.io/gorm"
) )
const dbVersion = "1" const dbVersion = "1"
@ -17,30 +17,49 @@ type KV struct {
} }
func (h *Headscale) initDB() error { func (h *Headscale) initDB() error {
db, err := gorm.Open(h.dbType, h.dbString) db, err := h.db()
if err != nil { if err != nil {
return err return err
} }
if h.dbType == "postgres" { if h.dbType == "postgres" {
db.Exec("create extension if not exists \"uuid-ossp\";") db.Exec("create extension if not exists \"uuid-ossp\";")
} }
db.AutoMigrate(&Machine{}) err = db.AutoMigrate(&Machine{})
db.AutoMigrate(&KV{}) if err != nil {
db.AutoMigrate(&Namespace{}) return err
db.AutoMigrate(&PreAuthKey{}) }
db.Close() err = db.AutoMigrate(&KV{})
if err != nil {
return err
}
err = db.AutoMigrate(&Namespace{})
if err != nil {
return err
}
err = db.AutoMigrate(&PreAuthKey{})
if err != nil {
return err
}
err = h.setValue("db_version", dbVersion) err = h.setValue("db_version", dbVersion)
return err return err
} }
func (h *Headscale) db() (*gorm.DB, error) { func (h *Headscale) db() (*gorm.DB, error) {
db, err := gorm.Open(h.dbType, h.dbString) var db *gorm.DB
var err error
switch h.dbType {
case "sqlite3":
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{})
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if h.dbDebug { if h.dbDebug {
db.LogMode(true) db.Debug()
} }
return db, nil return db, nil
} }
@ -50,9 +69,8 @@ func (h *Headscale) getValue(key string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer db.Close()
var row KV var row KV
if db.First(&row, "key = ?", key).RecordNotFound() { if result := db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", errors.New("not found") return "", errors.New("not found")
} }
return row.Value, nil return row.Value, nil
@ -67,7 +85,6 @@ func (h *Headscale) setValue(key string, value string) error {
if err != nil { if err != nil {
return err return err
} }
defer db.Close()
_, err = h.getValue(key) _, err = h.getValue(key)
if err == nil { if err == nil {
db.Model(&kv).Where("key = ?", key).Update("value", value) db.Model(&kv).Where("key = ?", key).Update("value", value)

View file

@ -159,7 +159,6 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
machines := []Machine{} machines := []Machine{}
if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered", if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered",

View file

@ -1,10 +1,11 @@
package headscale package headscale
import ( import (
"errors"
"log" "log"
"time" "time"
"github.com/jinzhu/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -29,7 +30,6 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
n := Namespace{} n := Namespace{}
if err := db.Where("name = ?", name).First(&n).Error; err == nil { if err := db.Where("name = ?", name).First(&n).Error; err == nil {
@ -51,7 +51,6 @@ func (h *Headscale) DestroyNamespace(name string) error {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return err return err
} }
defer db.Close()
n, err := h.GetNamespace(name) n, err := h.GetNamespace(name)
if err != nil { if err != nil {
@ -66,8 +65,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
return errorNamespaceNotEmpty return errorNamespaceNotEmpty
} }
err = db.Unscoped().Delete(&n).Error if result := db.Unscoped().Delete(&n); result.Error != nil {
if err != nil {
return err return err
} }
@ -81,10 +79,9 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
n := Namespace{} n := Namespace{}
if db.First(&n, "name = ?", name).RecordNotFound() { if result := db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorNamespaceNotFound return nil, errorNamespaceNotFound
} }
return &n, nil return &n, nil
@ -97,7 +94,6 @@ func (h *Headscale) ListNamespaces() (*[]Namespace, error) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
namespaces := []Namespace{} namespaces := []Namespace{}
if err := db.Find(&namespaces).Error; err != nil { if err := db.Find(&namespaces).Error; err != nil {
return nil, err return nil, err
@ -116,7 +112,6 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
machines := []Machine{} machines := []Machine{}
if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
@ -136,7 +131,6 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return err return err
} }
defer db.Close()
m.NamespaceID = n.ID m.NamespaceID = n.ID
db.Save(&m) db.Save(&m)
return nil return nil

View file

@ -3,8 +3,11 @@ package headscale
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"log" "log"
"time" "time"
"gorm.io/gorm"
) )
const errorAuthKeyNotFound = Error("AuthKey not found") const errorAuthKeyNotFound = Error("AuthKey not found")
@ -36,7 +39,6 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, epheme
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
now := time.Now().UTC() now := time.Now().UTC()
kstr, err := h.generateKey() kstr, err := h.generateKey()
@ -69,7 +71,6 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error)
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
return nil, err return nil, err
} }
defer db.Close()
keys := []PreAuthKey{} keys := []PreAuthKey{}
if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
@ -85,10 +86,9 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer db.Close()
pak := PreAuthKey{} pak := PreAuthKey{}
if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() { if result := db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorAuthKeyNotFound return nil, errorAuthKeyNotFound
} }

View file

@ -51,7 +51,6 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes) m.EnabledRoutes = datatypes.JSON(routes)
db.Save(&m) db.Save(&m)
db.Close()
// THIS IS COMPLETELY USELESS. // THIS IS COMPLETELY USELESS.
// The peers map is stored in memory in the server process. // The peers map is stored in memory in the server process.

View file

@ -18,6 +18,7 @@ import (
mathrand "math/rand" mathrand "math/rand"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
) )
@ -81,7 +82,6 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer db.Close()
i := 0 i := 0
for { for {
ip, err := getRandomIP() ip, err := getRandomIP()
@ -89,7 +89,7 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) {
return nil, err return nil, err
} }
m := Machine{} m := Machine{}
if db.First(&m, "ip_address = ?", ip.String()).RecordNotFound() { if result := db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil return ip, nil
} }
i++ i++