Update Headscale to depend on gorm v2

This commit is contained in:
Juan Font Alonso 2021-06-24 15:44:19 +02:00
parent df0d214faf
commit 69ba750b38
9 changed files with 47 additions and 40 deletions

9
api.go
View file

@ -3,6 +3,7 @@ package headscale
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@ -10,9 +11,9 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/klauspost/compress/zstd"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/wgcfg"
@ -80,10 +81,9 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.String(http.StatusInternalServerError, ":(")
return
}
defer db.Close()
var m Machine
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Println("New Machine!")
m = Machine{
Expiry: &req.Expiry,
@ -209,9 +209,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Printf("Cannot open DB: %s", err)
return
}
defer db.Close()
var m Machine
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString())
return
}

1
app.go
View file

@ -112,7 +112,6 @@ func (h *Headscale) expireEphemeralNodesWorker() {
log.Printf("Cannot open DB: %s", err)
return
}
defer db.Close()
namespaces, err := h.ListNamespaces()
if err != nil {

4
cli.go
View file

@ -4,6 +4,7 @@ import (
"errors"
"log"
"gorm.io/gorm"
"tailscale.com/wgengine/wgcfg"
)
@ -22,9 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
m := Machine{}
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found")
}

45
db.go
View file

@ -3,9 +3,9 @@ package headscale
import (
"errors"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/postgres" // sql driver
_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
const dbVersion = "1"
@ -17,30 +17,49 @@ type KV struct {
}
func (h *Headscale) initDB() error {
db, err := gorm.Open(h.dbType, h.dbString)
db, err := h.db()
if err != nil {
return err
}
if h.dbType == "postgres" {
db.Exec("create extension if not exists \"uuid-ossp\";")
}
db.AutoMigrate(&Machine{})
db.AutoMigrate(&KV{})
db.AutoMigrate(&Namespace{})
db.AutoMigrate(&PreAuthKey{})
db.Close()
err = db.AutoMigrate(&Machine{})
if err != nil {
return err
}
err = db.AutoMigrate(&KV{})
if err != nil {
return err
}
err = db.AutoMigrate(&Namespace{})
if err != nil {
return err
}
err = db.AutoMigrate(&PreAuthKey{})
if err != nil {
return err
}
err = h.setValue("db_version", dbVersion)
return err
}
func (h *Headscale) db() (*gorm.DB, error) {
db, err := gorm.Open(h.dbType, h.dbString)
var db *gorm.DB
var err error
switch h.dbType {
case "sqlite3":
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{})
}
if err != nil {
return nil, err
}
if h.dbDebug {
db.LogMode(true)
db.Debug()
}
return db, nil
}
@ -50,9 +69,8 @@ func (h *Headscale) getValue(key string) (string, error) {
if err != nil {
return "", err
}
defer db.Close()
var row KV
if db.First(&row, "key = ?", key).RecordNotFound() {
if result := db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", errors.New("not found")
}
return row.Value, nil
@ -67,7 +85,6 @@ func (h *Headscale) setValue(key string, value string) error {
if err != nil {
return err
}
defer db.Close()
_, err = h.getValue(key)
if err == nil {
db.Model(&kv).Where("key = ?", key).Update("value", value)

View file

@ -159,7 +159,6 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
machines := []Machine{}
if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered",

View file

@ -1,10 +1,11 @@
package headscale
import (
"errors"
"log"
"time"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
@ -29,7 +30,6 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
n := Namespace{}
if err := db.Where("name = ?", name).First(&n).Error; err == nil {
@ -51,7 +51,6 @@ func (h *Headscale) DestroyNamespace(name string) error {
log.Printf("Cannot open DB: %s", err)
return err
}
defer db.Close()
n, err := h.GetNamespace(name)
if err != nil {
@ -66,8 +65,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
return errorNamespaceNotEmpty
}
err = db.Unscoped().Delete(&n).Error
if err != nil {
if result := db.Unscoped().Delete(&n); result.Error != nil {
return err
}
@ -81,10 +79,9 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
n := Namespace{}
if db.First(&n, "name = ?", name).RecordNotFound() {
if result := db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorNamespaceNotFound
}
return &n, nil
@ -97,7 +94,6 @@ func (h *Headscale) ListNamespaces() (*[]Namespace, error) {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
namespaces := []Namespace{}
if err := db.Find(&namespaces).Error; err != nil {
return nil, err
@ -116,7 +112,6 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
machines := []Machine{}
if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
@ -136,7 +131,6 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
log.Printf("Cannot open DB: %s", err)
return err
}
defer db.Close()
m.NamespaceID = n.ID
db.Save(&m)
return nil

View file

@ -3,8 +3,11 @@ package headscale
import (
"crypto/rand"
"encoding/hex"
"errors"
"log"
"time"
"gorm.io/gorm"
)
const errorAuthKeyNotFound = Error("AuthKey not found")
@ -36,7 +39,6 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, epheme
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
now := time.Now().UTC()
kstr, err := h.generateKey()
@ -69,7 +71,6 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error)
log.Printf("Cannot open DB: %s", err)
return nil, err
}
defer db.Close()
keys := []PreAuthKey{}
if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
@ -85,10 +86,9 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
if err != nil {
return nil, err
}
defer db.Close()
pak := PreAuthKey{}
if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() {
if result := db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorAuthKeyNotFound
}

View file

@ -51,7 +51,6 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes)
db.Save(&m)
db.Close()
// THIS IS COMPLETELY USELESS.
// The peers map is stored in memory in the server process.

View file

@ -18,6 +18,7 @@ import (
mathrand "math/rand"
"golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"tailscale.com/wgengine/wgcfg"
)
@ -81,7 +82,6 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) {
if err != nil {
return nil, err
}
defer db.Close()
i := 0
for {
ip, err := getRandomIP()
@ -89,7 +89,7 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) {
return nil, err
}
m := Machine{}
if db.First(&m, "ip_address = ?", ip.String()).RecordNotFound() {
if result := db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil
}
i++