Replace database locks with transactions (#1701)

This commits removes the locks used to guard data integrity for the
database and replaces them with Transactions, turns out that SQL had
a way to deal with this all along.

This reduces the complexity we had with multiple locks that might stack
or recurse (database, nofitifer, mapper). All notifications and state
updates are now triggered _after_ a database change.


Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-02-08 17:28:19 +01:00 committed by GitHub
parent cbf57e27a7
commit 83769ba715
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1496 additions and 1128 deletions

View file

@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473)
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
- The latest supported client is 1.36
- The latest supported client is 1.38
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
- If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url.
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611)

View file

@ -6,25 +6,11 @@ import (
"github.com/efekarakus/termcolor"
"github.com/juanfont/headscale/cmd/headscale/cli"
"github.com/pkg/profile"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var colors bool
switch l := termcolor.SupportLevel(os.Stderr); l {
case termcolor.Level16M:

View file

@ -33,6 +33,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/pkg/profile"
"github.com/prometheus/client_golang/prometheus/promhttp"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -48,6 +49,7 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
@ -61,7 +63,7 @@ var (
"unknown value for Lets Encrypt challenge type",
)
errEmptyInitialDERPMap = errors.New(
"initial DERPMap is empty, Headscale requries at least one entry",
"initial DERPMap is empty, Headscale requires at least one entry",
)
)
@ -166,7 +168,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
cfg.DBtype,
dbString,
app.dbDebug,
app.nodeNotifier,
cfg.IPPrefixes,
cfg.BaseDomain)
if err != nil {
@ -234,8 +235,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
var update types.StateUpdate
var changed bool
for range ticker.C {
h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout)
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
@ -246,9 +262,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
ticker := time.NewTicker(interval)
lastCheck := time.Unix(0, 0)
var update types.StateUpdate
var changed bool
for range ticker.C {
lastCheck = h.db.ExpireExpiredNodes(lastCheck)
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
@ -278,7 +309,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
DERPMap: h.DERPMap,
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyAll(stateUpdate)
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
}
}
@ -485,6 +517,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var err error
// Fetch an initial DERP Map before we start serving
@ -753,7 +798,8 @@ func (h *Headscale) Serve() error {
Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change")
h.nodeNotifier.NotifyAll(types.StateUpdate{
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}

View file

@ -1,6 +1,7 @@
package hscontrol
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -8,6 +9,7 @@ import (
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -243,8 +245,6 @@ func (h *Headscale) handleRegister(
// handleAuthKey contains the logic to manage auth key client registration
// When using Noise, the machineKey is Zero.
//
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
@ -311,6 +311,9 @@ func (h *Headscale) handleAuthKey(
nodeKey := registerRequest.NodeKey
var update types.StateUpdate
var mkey key.MachinePublic
// retrieve node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
@ -324,7 +327,7 @@ func (h *Headscale) handleAuthKey(
node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID)
err := h.db.NodeSetExpiry(node, registerRequest.Expiry)
err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
@ -335,10 +338,13 @@ func (h *Headscale) handleAuthKey(
return
}
mkey = node.MachineKey
update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
aclTags := pak.Proto().GetAclTags()
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.db.SetTags(node, aclTags)
err = h.db.SetTags(node.ID, aclTags)
if err != nil {
log.Error().
@ -370,6 +376,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
@ -393,9 +400,18 @@ func (h *Headscale) handleAuthKey(
return
}
mkey = node.MachineKey
update = types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from auth.handleAuthKey",
}
}
err = h.db.UsePreAuthKey(pak)
err = h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak)
})
if err != nil {
log.Error().
Caller().
@ -437,6 +453,13 @@ func (h *Headscale) handleAuthKey(
Caller().
Err(err).
Msg("Failed to write response")
return
}
// TODO(kradalby): if notifying after register make sense.
if update.Valid() {
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
}
log.Info().
@ -502,7 +525,7 @@ func (h *Headscale) handleNodeLogOut(
Msg("Client requested logout")
now := time.Now()
err := h.db.NodeSetExpiry(&node, now)
err := h.db.NodeSetExpiry(node.ID, now)
if err != nil {
log.Error().
Caller().
@ -513,17 +536,10 @@ func (h *Headscale) handleNodeLogOut(
return
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &now,
},
},
}
stateUpdate := types.StateUpdateExpire(node.ID, now)
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
resp.AuthURL = ""
@ -554,7 +570,7 @@ func (h *Headscale) handleNodeLogOut(
}
if node.IsEphemeral() {
err = h.db.DeleteNode(&node)
err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
if err != nil {
log.Error().
Err(err).
@ -562,6 +578,15 @@ func (h *Headscale) handleNodeLogOut(
Msg("Cannot delete ephemeral node from the database")
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return
}
@ -633,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh(
Str("node", node.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey)
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
})
if err != nil {
log.Error().
Caller().

View file

@ -13,16 +13,23 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
return getAvailableIPs(rx, hsdb.ipPrefixes)
})
}
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
var ips types.NodeAddresses
var err error
for _, ipPrefix := range hsdb.ipPrefixes {
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
ip, err = hsdb.getAvailableIP(ipPrefix)
ip, err = getAvailableIP(rx, ipPrefix)
if err != nil {
return ips, err
}
@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return ips, err
}
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := hsdb.getUsedIPs()
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := getUsedIPs(rx)
if err != nil {
return nil, err
}
@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro
}
}
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
for _, slice := range addressesSlices {

View file

@ -7,10 +7,16 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ips, err := db.getAvailableIPs()
tx := db.DB.Begin()
defer tx.Rollback()
ips, err := getAvailableIPs(tx, []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
})
c.Assert(err, check.IsNil)
@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.db.Save(&node)
usedIps, err := db.getUsedIPs()
db.Write(func(tx *gorm.DB) error {
return tx.Save(&node).Error
})
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
}
func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := db.CreateUser("test-ip-multi")
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
ipPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
}
for index := 1; index <= 350; index++ {
db.ipAllocationMutex.Lock()
tx := db.DB.Begin()
ips, err := db.getAvailableIPs()
ips, err := getAvailableIPs(tx, ipPrefixes)
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = getNode(tx, "test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.db.Save(&node)
db.ipAllocationMutex.Unlock()
tx.Save(&node)
c.Assert(tx.Commit().Error, check.IsNil)
}
usedIps, err := db.getUsedIPs()
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)

View file

@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time,
) (string, *types.APIKey, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
if err != nil {
return "", nil, err
@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
Expiration: expiration,
}
if err := hsdb.db.Save(&key).Error; err != nil {
if err := hsdb.DB.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
}
@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
// ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
keys := []types.APIKey{}
if err := hsdb.db.Find(&keys).Error; err != nil {
if err := hsdb.DB.Find(&keys).Error; err != nil {
return nil, err
}
@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
// GetAPIKey returns a ApiKey for a given key.
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{}
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error
}
@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
// GetAPIKeyByID returns a ApiKey for a given id.
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{}
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error
}
@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
// does not exist.
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
return result.Error
}
@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
// ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err
}
@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
}
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
prefix, hash, found := strings.Cut(keyStr, ".")
if !found {
return false, ErrAPIKeyFailedToParse

View file

@ -7,12 +7,10 @@ import (
"fmt"
"net/netip"
"strings"
"sync"
"time"
"github.com/glebarez/sqlite"
"github.com/go-gormigrate/gormigrate/v2"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -36,12 +34,7 @@ type KV struct {
}
type HSDatabase struct {
db *gorm.DB
notifier *notifier.Notifier
mu sync.RWMutex
ipAllocationMutex sync.Mutex
DB *gorm.DB
ipPrefixes []netip.Prefix
baseDomain string
@ -52,7 +45,6 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
dbType, connectionAddr string,
debug bool,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
@ -147,7 +139,9 @@ func NewHeadscaleDatabase(
DiscoKey string
}
var results []result
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").
Find(&results).
Error
if err != nil {
return err
}
@ -180,7 +174,8 @@ func NewHeadscaleDatabase(
}
if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
log.Info().
Msgf("Database has legacy enabled_routes column in node, migrating...")
type NodeAux struct {
ID uint64
@ -317,8 +312,7 @@ func NewHeadscaleDatabase(
}
db := HSDatabase{
db: dbConn,
notifier: notifier,
DB: dbConn,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
@ -376,7 +370,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
sqlDB, err := hsdb.db.DB()
sqlDB, err := hsdb.DB.DB()
if err != nil {
return err
}
@ -385,10 +379,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
}
func (hsdb *HSDatabase) Close() error {
db, err := hsdb.db.DB()
db, err := hsdb.DB.DB()
if err != nil {
return err
}
return db.Close()
}
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
rx := hsdb.DB.Begin()
defer rx.Rollback()
return fn(rx)
}
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
rx := db.Begin()
defer rx.Rollback()
ret, err := fn(rx)
if err != nil {
var no T
return no, err
}
return ret, nil
}
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
tx := hsdb.DB.Begin()
defer tx.Rollback()
if err := fn(tx); err != nil {
return err
}
return tx.Commit().Error
}
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
tx := db.Begin()
defer tx.Rollback()
ret, err := fn(tx)
if err != nil {
var no T
return no, err
}
return ret, tx.Commit().Error
}

View file

@ -34,22 +34,21 @@ var (
)
)
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listPeers(node)
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, node)
})
}
func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Msg("Finding direct peers")
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodes()
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
}
func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
nodes := []types.Node{}
if err := hsdb.db.
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByGivenName(givenName)
}
func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err
return nodes, nil
}
// GetNode finds a Node by name and user and returns the Node struct.
func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name)
})
}
nodes, err := hsdb.ListNodesByUser(user)
// getNode finds a Node by name and user and returns the Node struct.
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
nodes, err := ListNodesByUser(tx, user)
if err != nil {
return nil, err
}
@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
return nil, ErrNodeNotFound
}
// GetNodeByGivenName finds a Node by given name and user and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByGivenName(
user string,
givenName string,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if err := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).First(&node).Error; err != nil {
return nil, err
}
return nil, ErrNodeNotFound
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id)
})
}
// GetNodeByID finds a Node by ID and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
return &mach, nil
}
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByMachineKey(
machineKey key.MachinePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeByMachineKey(machineKey)
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByMachineKey(rx, machineKey)
})
}
func (hsdb *HSDatabase) getNodeByMachineKey(
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
func GetNodeByMachineKey(
tx *gorm.DB,
machineKey key.MachinePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey(
return &mach, nil
}
// GetNodeByNodeKey finds a Node by its current NodeKey.
func (hsdb *HSDatabase) GetNodeByNodeKey(
func (hsdb *HSDatabase) GetNodeByAnyKey(
machineKey key.MachinePublic,
nodeKey key.NodePublic,
oldNodeKey key.NodePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if result := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&node, "node_key = ?",
nodeKey.String()); result.Error != nil {
return nil, result.Error
}
return &node, nil
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
})
}
// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByAnyKey(
// TODO(kradalby): see if we can remove this.
func GetNodeByAnyKey(
tx *gorm.DB,
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
return &node, nil
}
func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
if result := hsdb.db.Find(node).First(&node); result.Error != nil {
return result.Error
}
return nil
func (hsdb *HSDatabase) SetTags(
nodeID uint64,
tags []string,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
return SetTags(tx, nodeID, tags)
})
}
// SetTags takes a Node struct pointer and update the forced tags.
func (hsdb *HSDatabase) SetTags(
node *types.Node,
func SetTags(
tx *gorm.DB,
nodeID uint64,
tags []string,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if len(tags) == 0 {
return nil
}
newTags := []string{}
newTags := types.StringList{}
for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) {
newTags = append(newTags, tag)
}
}
if err := hsdb.db.Model(node).Updates(types.Node{
ForcedTags: newTags,
}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.SetTags",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
}
// RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it.
func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func RenameNode(tx *gorm.DB,
nodeID uint64, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
)
@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
log.Error().
Caller().
Str("func", "RenameNode").
Str("node", node.Hostname).
Uint64("nodeID", nodeID).
Str("newName", newName).
Err(err).
Msg("failed to rename node")
return err
}
node.GivenName = newName
if err := hsdb.db.Model(node).Updates(types.Node{
GivenName: newName,
}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.RenameNode",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
return nil
}
return nil
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry)
})
}
// NodeSetExpiry takes a Node struct and a new expiry time.
func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.nodeSetExpiry(node, expiry)
func NodeSetExpiry(tx *gorm.DB,
nodeID uint64, expiry time.Time,
) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}
func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error {
if err := hsdb.db.Model(node).Updates(types.Node{
Expiry: &expiry,
}).Error; err != nil {
return fmt.Errorf(
"failed to refresh node (update expiration) in the database: %w",
err,
)
}
node.Expiry = &expiry
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &expiry,
},
},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DeleteNode(tx, node, isConnected)
})
}
// DeleteNode deletes a Node from the database.
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.deleteNode(node)
}
func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
err := hsdb.deleteNodeRoutes(node)
// Caller is responsible for notifying all of change.
func DeleteNode(tx *gorm.DB,
node *types.Node,
isConnected map[key.MachinePublic]bool,
) error {
err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
if err != nil {
return err
}
// Unscoped causes the node to be fully removed from the database.
if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil {
if err := tx.Unscoped().Delete(&node).Error; err != nil {
return err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
// UpdateLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node.
// This is mostly used to indicate if a node is online and is not
// extremely important to make sure is fully correct and to avoid
// holding up the hot path, does not contain any locks and isnt
// concurrency safe. But that should be ok.
func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
return hsdb.db.Model(node).Updates(types.Node{
LastSeen: node.LastSeen,
}).Error
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
func RegisterNodeFromAuthCallback(
tx *gorm.DB,
cache *cache.Cache,
mkey key.MachinePublic,
userName string,
nodeExpiry *time.Time,
registrationMethod string,
ipPrefixes []netip.Prefix,
) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
log.Debug().
Str("machine_key", mkey.ShortString()).
Str("userName", userName).
@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName)
user, err := GetUser(tx, userName)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
}
registrationNode.UserID = user.ID
registrationNode.User = *user
registrationNode.RegisterMethod = registrationMethod
if nodeExpiry != nil {
registrationNode.Expiry = nodeExpiry
}
node, err := hsdb.registerNode(
node, err := RegisterNode(
tx,
registrationNode,
ipPrefixes,
)
if err == nil {
@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.registerNode(node)
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, hsdb.ipPrefixes)
})
}
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
// so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache
if len(node.IPAddresses) > 0 {
if err := hsdb.db.Save(&node).Error; err != nil {
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
}
@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
return &node, nil
}
hsdb.ipAllocationMutex.Lock()
defer hsdb.ipAllocationMutex.Unlock()
ips, err := hsdb.getAvailableIPs()
ips, err := getAvailableIPs(tx, ipPrefixes)
if err != nil {
log.Error().
Caller().
@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
node.IPAddresses = ips
if err := hsdb.db.Save(&node).Error; err != nil {
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
}
@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
}
// NodeSetNodeKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error {
return tx.Model(node).Updates(types.Node{
NodeKey: nodeKey,
}).Error; err != nil {
return err
}).Error
}
return nil
}
// NodeSetMachineKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetMachineKey(
node *types.Node,
machineKey key.MachinePublic,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: machineKey,
}).Error; err != nil {
return err
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetMachineKey(tx, node, machineKey)
})
}
return nil
// NodeSetMachineKey sets the node key of a node and saves it to the database.
func NodeSetMachineKey(
tx *gorm.DB,
node *types.Node,
machineKey key.MachinePublic,
) error {
return tx.Model(node).Updates(types.Node{
MachineKey: machineKey,
}).Error
}
// NodeSave saves a node object to the database, prefer to use a specific save method rather
// than this. It is intended to be used when we are changing or.
func (hsdb *HSDatabase) NodeSave(node *types.Node) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Save(node).Error; err != nil {
return err
// TODO(kradalby): Remove this func, just use Save.
func NodeSave(tx *gorm.DB, node *types.Node) error {
return tx.Save(node).Error
}
return nil
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetAdvertisedRoutes(rx, node)
})
}
// GetAdvertisedRoutes returns the routes that are be advertised by the given node.
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e
return prefixes, nil
}
// GetEnabledRoutes returns the routes that are enabled for the node.
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getEnabledRoutes(node)
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetEnabledRoutes(rx, node)
})
}
func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
// GetEnabledRoutes returns the routes that are enabled for the node.
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
Find(&routes).Error
@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro
return prefixes, nil
}
func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := hsdb.getEnabledRoutes(node)
enabledRoutes, err := GetEnabledRoutes(tx, node)
if err != nil {
log.Error().Err(err).Msg("Could not get enabled routes")
@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
return false
}
func (hsdb *HSDatabase) enableRoutes(
node *types.Node,
routeStrs ...string,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return enableRoutes(tx, node, routeStrs...)
})
}
// enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
func enableRoutes(tx *gorm.DB,
node *types.Node, routeStrs ...string,
) (*types.StateUpdate, error) {
newRoutes := make([]netip.Prefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return err
return nil, err
}
newRoutes[index] = route
}
advertisedRoutes, err := hsdb.getAdvertisedRoutes(node)
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
if err != nil {
return err
return nil, err
}
for _, newRoute := range newRoutes {
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
return fmt.Errorf(
return nil, fmt.Errorf(
"route (%s) is not available on node %s: %w",
node.Hostname,
newRoute, ErrNodeRouteIsNotAvailable,
@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes {
route := types.Route{}
err := hsdb.db.Preload("Node").
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
First(&route).Error
if err == nil {
@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Mark already as primary if there is only this node offering this subnet
// (and is not an exit route)
if !route.IsExitRoute() {
route.IsPrimary = hsdb.isUniquePrefix(route)
route.IsPrimary = isUniquePrefix(tx, route)
}
err = hsdb.db.Save(&route).Error
err = tx.Save(&route).Error
if err != nil {
return fmt.Errorf("failed to enable route: %w", err)
return nil, fmt.Errorf("failed to enable route: %w", err)
}
} else {
return fmt.Errorf("failed to find route: %w", err)
return nil, fmt.Errorf("failed to find route: %w", err)
}
}
// Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := hsdb.getNodeRoutes(node)
nRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return fmt.Errorf("failed to read back routes: %w", err)
return nil, fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
@ -729,30 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
Strs("routes", routeStrs).
Msg("enabling routes")
stateUpdate := types.StateUpdate{
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.enableRoutes",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(
stateUpdate, node.MachineKey.String())
}
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(
selfUpdate,
node.MachineKey)
}
return nil
Message: "created in db.enableRoutes",
}, nil
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
@ -785,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
return GenerateGivenName(rx, mkey, suppliedName)
})
}
func GenerateGivenName(
tx *gorm.DB,
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
givenName, err := generateGivenName(suppliedName, false)
if err != nil {
return "", err
}
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
nodes, err := hsdb.listNodesByGivenName(givenName)
nodes, err := listNodesByGivenName(tx, givenName)
if err != nil {
return "", err
}
@ -818,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName(
return givenName, nil
}
func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
users, err := hsdb.listUsers()
func ExpireEphemeralNodes(tx *gorm.DB,
inactivityThreshhold time.Duration,
) (types.StateUpdate, bool) {
users, err := ListUsers(tx)
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
return types.StateUpdate{}, false
}
expired := make([]tailcfg.NodeID, 0)
for _, user := range users {
nodes, err := hsdb.listNodesByUser(user.Name)
nodes, err := ListNodesByUser(tx, user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing nodes in user")
return
return types.StateUpdate{}, false
}
expired := make([]tailcfg.NodeID, 0)
for idx, node := range nodes {
if node.IsEphemeral() && node.LastSeen != nil &&
time.Now().
@ -851,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
Str("node", node.Hostname).
Msg("Ephemeral client removed from database")
err = hsdb.deleteNode(nodes[idx])
// empty isConnected map as ephemeral nodes are not routes
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
if err != nil {
log.Error().
Err(err).
@ -861,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
}
}
// TODO(kradalby): needs to be moved out of transaction
}
if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
return types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
}
}
}, true
}
func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return types.StateUpdate{}, false
}
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {
// use the time of the start of the function to ensure we
// dont miss some nodes by returning it _after_ we have
// checked everything.
started := time.Now()
expiredNodes := make([]*types.Node, 0)
expired := make([]*tailcfg.PeerChange, 0)
nodes, err := hsdb.listNodes()
nodes, err := ListNodes(tx)
if err != nil {
log.Error().
Err(err).
Msg("Error listing nodes to find expired nodes")
return time.Unix(0, 0)
return time.Unix(0, 0), types.StateUpdate{}, false
}
for index, node := range nodes {
if node.IsExpired() &&
@ -895,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
// It will notify about all nodes that has been expired.
// It should only notify about expired nodes since _last check_.
node.Expiry.After(lastCheck) {
expiredNodes = append(expiredNodes, &nodes[index])
expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
now := time.Now()
// Do not use setNodeExpiry as that has a notifier hook, which
// can cause a deadlock, we are updating all changed nodes later
// and there is no point in notifiying twice.
if err := hsdb.db.Model(&nodes[index]).Updates(types.Node{
Expiry: &started,
if err := tx.Model(&nodes[index]).Updates(types.Node{
Expiry: &now,
}).Error; err != nil {
log.Error().
Err(err).
@ -917,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
}
}
expired := make([]*tailcfg.PeerChange, len(expiredNodes))
for idx, node := range expiredNodes {
expired[idx] = &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &started,
}
}
// Inform the peers of a node with a lightweight update.
stateUpdate := types.StateUpdate{
if len(expired) > 0 {
return started, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}, true
}
// Inform the node itself that it has expired.
for _, node := range expiredNodes {
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
}
return started
return started, types.StateUpdate{}, false
}

View file

@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
}
@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
_, err = db.GetNodeByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.db.Save(&node)
db.DB.Save(&node)
err = db.DeleteNode(&node)
err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil)
_, err = db.GetNode(user.Name, "testnode3")
_, err = db.getNode(user.Name, "testnode3")
c.Assert(err, check.NotNil)
}
@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
}
node0ByID, err := db.GetNodeByID(0)
@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
}
aclPolicy := &policy.ACLPolicy{
@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) {
AuthKeyID: uint(pak.ID),
Expiry: &time.Time{},
}
db.db.Save(node)
db.DB.Save(node)
nodeFromDB, err := db.GetNode("test", "testnode")
nodeFromDB, err := db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(nodeFromDB, check.NotNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, false)
now := time.Now()
err = db.NodeSetExpiry(nodeFromDB, now)
err = db.NodeSetExpiry(nodeFromDB.ID, now)
c.Assert(err, check.IsNil)
nodeFromDB, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("user-1", "testnode")
_, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
// assign simple tags
sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(node, sTags)
err = db.SetTags(node.ID, sTags)
c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode")
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
// assign duplicat tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(node, eTags)
err = db.SetTags(node.ID, eTags)
c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode")
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(
node.ForcedTags,
@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
}
db.db.Save(&node)
db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)
err = db.EnableAutoApprovedRoutes(pol, node0ByID)
// TODO(kradalby): Check state update
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(node0ByID)

View file

@ -20,7 +20,6 @@ var (
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
)
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func (hsdb *HSDatabase) CreatePreAuthKey(
userName string,
reusable bool,
@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
// TODO(kradalby): figure out this lock
// hsdb.mu.Lock()
// defer hsdb.mu.Unlock()
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
})
}
user, err := hsdb.GetUser(userName)
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil {
return nil, err
}
@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
}
now := time.Now().UTC()
kstr, err := hsdb.generateKey()
kstr, err := generateKey()
if err != nil {
return nil, err
}
@ -63,9 +72,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
Expiration: expiration,
}
err = hsdb.db.Transaction(func(db *gorm.DB) error {
if err := db.Save(&key).Error; err != nil {
return fmt.Errorf("failed to create key in the database: %w", err)
if err := tx.Save(&key).Error; err != nil {
return nil, fmt.Errorf("failed to create key in the database: %w", err)
}
if len(aclTags) > 0 {
@ -73,8 +81,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
for _, tag := range aclTags {
if !seenTags[tag] {
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return nil, fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
@ -84,9 +92,6 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
}
}
return nil
})
if err != nil {
return nil, err
}
@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
return &key, nil
}
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listPreAuthKeys(userName)
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeys(rx, userName)
})
}
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
user, err := hsdb.getUser(userName)
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil {
return nil, err
}
keys := []types.PreAuthKey{}
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
return nil, err
}
@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
}
// GetPreAuthKey returns a PreAuthKey for a given key.
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
pak, err := hsdb.ValidatePreAuthKey(key)
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
pak, err := ValidatePreAuthKey(tx, key)
if err != nil {
return nil, err
}
@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist.
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.destroyPreAuthKey(pak)
}
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
return hsdb.db.Transaction(func(db *gorm.DB) error {
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
return tx.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
return result.Error
}
@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
})
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.Write(func(tx *gorm.DB) error {
return ExpirePreAuthKey(tx, k)
})
}
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
}
// UsePreAuthKey marks a PreAuthKey as used.
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
if err := hsdb.db.Save(k).Error; err != nil {
if err := tx.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err)
}
return nil
}
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
}
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
}
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
Find(&nodes).Error; err != nil {
@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
return &pak, nil
}
func (hsdb *HSDatabase) generateKey() (string, error) {
func generateKey() (string, error) {
size := 24
bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil {

View file

@ -6,6 +6,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := db.CreateUser("test2")
c.Assert(err, check.IsNil)
now := time.Now()
now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil)
@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) {
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.ValidatePreAuthKey(pak.Key)
// Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil)
_, err = db.GetNode("test7", "testest")
_, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil)
db.ExpireEphemeralNodes(time.Second * 20)
db.DB.Transaction(func(tx *gorm.DB) error {
ExpireEphemeralNodes(tx, time.Second*20)
return nil
})
// The machine record should have been deleted
_, err = db.GetNode("test7", "testest")
_, err = db.getNode("test7", "testest")
c.Assert(err, check.NotNil)
}
@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
pak.Used = true
db.db.Save(&pak)
db.DB.Save(&pak)
_, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)

View file

@ -7,23 +7,15 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"gorm.io/gorm"
"tailscale.com/types/key"
)
var ErrRouteIsNotAvailable = errors.New("route is not available")
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoutes()
}
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Find(&routes).Error
@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true).
@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
return routes, nil
}
func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID).
@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
}
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeRoutes(node)
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodeRoutes(rx, node)
})
}
func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ?", node.ID).
@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoute(id)
}
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
var route types.Route
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
First(&route, id).Error
@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
return &route, nil
}
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.enableRoute(id)
}
func (hsdb *HSDatabase) enableRoute(id uint64) error {
route, err := hsdb.getRoute(id)
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.IsExitRoute() {
return hsdb.enableRoutes(
return enableRoutes(
tx,
&route.Node,
types.ExitRouteV4.String(),
types.ExitRouteV6.String(),
)
}
return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String())
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
}
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
route, err := hsdb.getRoute(id)
func DisableRoute(tx *gorm.DB,
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
var routes types.Routes
@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() {
err = hsdb.failoverRouteWithNotify(route)
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil {
return err
return nil, err
}
route.Enabled = false
route.IsPrimary = false
err = hsdb.db.Save(route).Error
err = tx.Save(route).Error
if err != nil {
return err
return nil, err
}
} else {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
for i := range routes {
if routes[i].IsExitRoute() {
routes[i].Enabled = false
routes[i].IsPrimary = false
err = hsdb.db.Save(&routes[i]).Error
err = tx.Save(&routes[i]).Error
if err != nil {
return err
return nil, err
}
}
}
}
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
// If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DisableRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
return update, nil
}
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func (hsdb *HSDatabase) DeleteRoute(
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return DeleteRoute(tx, id, isConnected)
})
}
route, err := hsdb.getRoute(id)
func DeleteRoute(
tx *gorm.DB,
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
var routes types.Routes
@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() {
err := hsdb.failoverRouteWithNotify(route)
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil {
return nil
return nil, nil
}
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
if err := tx.Unscoped().Delete(&route).Error; err != nil {
return nil, err
}
} else {
routes, err := hsdb.getNodeRoutes(&node)
routes, err := GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
routesToDelete := types.Routes{}
@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
}
}
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
return nil, err
}
}
// If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DeleteRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
return update, nil
}
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
routes, err := GetNodeRoutes(tx, node)
if err != nil {
return err
}
for i := range routes {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return err
}
// TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all.
hsdb.failoverRouteWithNotify(&routes[i])
failoverRouteReturnUpdate(tx, isConnected, &routes[i])
}
return nil
}
// isUniquePrefix returns if there is another node providing the same route already.
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64
hsdb.db.
Model(&types.Route{}).
tx.Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.NodeID,
@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
return count == 0
}
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route
err := hsdb.db.
err := tx.
Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
First(&route).Error
@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
return &route, nil
}
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodePrimaryRoutes(rx, node)
})
}
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary.
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
Find(&routes).Error
@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
return routes, nil
}
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
return SaveNodeRoutes(tx, node)
})
}
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool whether an update should be sent as the
// saved route impacts nodes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.saveNodeRoutes(node)
}
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
sendUpdate := false
currentRoutes := types.Routes{}
err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
err := tx.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil {
return sendUpdate, err
}
@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised {
currentRoutes[pos].Advertised = true
err := hsdb.db.Save(&currentRoutes[pos]).Error
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
} else if route.Advertised {
currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false
err := hsdb.db.Save(&currentRoutes[pos]).Error
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
Advertised: true,
Enabled: false,
}
err := hsdb.db.Create(&route).Error
err := tx.Create(&route).Error
if err != nil {
return sendUpdate, err
}
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
// currently have a functioning host that exposes the network.
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
nodeRoutes, err := hsdb.getNodeRoutes(node)
func EnsureFailoverRouteIsAvailable(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
node *types.Node,
) (*types.StateUpdate, error) {
nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil
return nil, nil
}
var changedNodes types.Nodes
for _, nodeRoute := range nodeRoutes {
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil {
return err
return nil, err
}
for _, route := range routes {
if route.IsPrimary {
// if we have a primary route, and the node is connected
// nothing needs to be done.
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
if isConnected[route.Node.MachineKey] {
continue
}
// if not, we need to failover the route
err := hsdb.failoverRouteWithNotify(&route)
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
if err != nil {
return err
return nil, err
}
if update != nil {
changedNodes = append(changedNodes, update.ChangeNodes...)
}
}
}
}
return nil
}
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
if err != nil {
return nil
}
var changedKeys []key.MachinePublic
for _, route := range routes {
changed, err := hsdb.failoverRoute(&route)
if err != nil {
return err
}
changedKeys = append(changedKeys, changed...)
}
changedKeys = lo.Uniq(changedKeys)
var nodes types.Nodes
for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
if err != nil {
return err
}
nodes = append(nodes, node)
}
if nodes != nil {
stateUpdate := types.StateUpdate{
if len(changedNodes) != 0 {
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.FailoverNodeRoutesWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
ChangeNodes: changedNodes,
Message: "called from db.EnsureFailoverRouteIsAvailable",
}, nil
}
return nil
return nil, nil
}
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
changedKeys, err := hsdb.failoverRoute(r)
func failoverRouteReturnUpdate(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) (*types.StateUpdate, error) {
changedKeys, err := failoverRoute(tx, isConnected, r)
if err != nil {
return err
return nil, err
}
log.Trace().
Interface("isConnected", isConnected).
Interface("changedKeys", changedKeys).
Msg("building route failover")
if len(changedKeys) == 0 {
return nil
return nil, nil
}
var nodes types.Nodes
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("loading machines with new primary routes from db")
for _, key := range changedKeys {
node, err := hsdb.getNodeByMachineKey(key)
node, err := GetNodeByMachineKey(tx, key)
if err != nil {
return err
return nil, err
}
nodes = append(nodes, node)
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notifying peers about primary route change")
if nodes != nil {
stateUpdate := types.StateUpdate{
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.failoverRouteWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notified peers about primary route change")
return nil
Message: "called from db.failoverRouteReturnUpdate",
}, nil
}
// failoverRoute takes a route that is no longer available,
@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
//
// and tries to find a new route to take over its place.
// If the given route was not primary, it returns early.
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
func failoverRoute(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) ([]key.MachinePublic, error) {
if r == nil {
return nil, nil
}
// This route is not a primary route, and it isnt
// This route is not a primary route, and it is not
// being served to nodes.
if !r.IsPrimary {
return nil, nil
@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return nil, nil
}
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
if err != nil {
return nil, err
}
@ -589,14 +547,14 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
continue
}
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
if isConnected[route.Node.MachineKey] {
newPrimary = &routes[idx]
break
}
}
// If a new route was not found/available,
// return with an error.
// return without an error.
// We do not want to update the database as
// the one currently marked as primary is the
// best we got.
@ -610,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Remove primary from the old route
r.IsPrimary = false
err = hsdb.db.Save(&r).Error
err = tx.Save(&r).Error
if err != nil {
log.Error().Err(err).Msg("error disabling new primary route")
@ -623,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Set primary for the new primary
newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error
err = tx.Save(&newPrimary).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
@ -638,25 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
node *types.Node,
) error {
if len(aclPolicy.AutoApprovers.ExitNode) == 0 && len(aclPolicy.AutoApprovers.Routes) == 0 {
// No autoapprovers configured
return nil
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
})
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
aclPolicy *policy.ACLPolicy,
node *types.Node,
) (*types.StateUpdate, error) {
if len(node.IPAddresses) == 0 {
// This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
return nil
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
routes, err := hsdb.getNodeAdvertisedRoutes(node)
routes, err := GetNodeAdvertisedRoutes(tx, node)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().
Caller().
@ -664,7 +623,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("node", node.Hostname).
Msg("Could not get advertised routes for node")
return err
return nil, err
}
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
@ -685,7 +644,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Uint64("nodeId", node.ID).
Msg("Failed to resolve autoApprovers for advertised route")
return err
return nil, err
}
log.Trace().
@ -706,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy")
return err
return nil, err
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
@ -717,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
}
}
update := &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{},
Message: "created in db.EnableAutoApprovedRoutes",
}
for _, approvedRoute := range approvedRoutes {
err := hsdb.enableRoute(uint64(approvedRoute.ID))
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID).
Msg("Failed to enable approved route")
return err
}
return nil, err
}
return nil
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
}
return update, nil
}

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
@ -24,7 +23,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_get_route_node")
_, err = db.getNode("test", "test_get_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24")
@ -42,7 +41,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -52,10 +51,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1)
err = db.enableRoutes(&node, "192.168.0.0/24")
// TODO(kradalby): check state update
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
@ -66,7 +66,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -91,7 +91,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -106,10 +106,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = db.enableRoutes(&node, "192.168.0.0/24")
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(&node)
@ -117,14 +117,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = db.enableRoutes(&node, "150.0.10.0/25")
_, err = db.enableRoutes(&node, "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
@ -139,7 +139,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -163,16 +163,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String())
_, err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, route2.String())
_, err = db.enableRoutes(&node1, route2.String())
c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{
@ -186,13 +186,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2,
}
db.db.Save(&node2)
db.DB.Save(&node2)
sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String())
_, err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -219,7 +219,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
@ -246,22 +246,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String())
_, err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String())
_, err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil)
routes, err := db.GetNodeRoutes(&node1)
c.Assert(err, check.IsNil)
err = db.DeleteRoute(uint64(routes[0].ID))
// TODO(kradalby): check stateupdate
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -269,17 +270,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(len(enabledRoutes1), check.Equals, 1)
}
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
func TestFailoverRoute(t *testing.T) {
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Count/verify updates
var sink chan types.StateUpdate
go func() {
for range sink {
}
}()
machineKeys := []key.MachinePublic{
key.NewMachine().Public(),
key.NewMachine().Public(),
@ -291,6 +284,7 @@ func TestFailoverRoute(t *testing.T) {
name string
failingRoute types.Route
routes types.Routes
isConnected map[key.MachinePublic]bool
want []key.MachinePublic
wantErr bool
}{
@ -397,6 +391,10 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
@ -491,6 +489,11 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[1]: true,
machineKeys[2]: true,
},
want: []key.MachinePublic{
machineKeys[1],
machineKeys[0],
@ -535,6 +538,10 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[3]: false,
},
want: nil,
wantErr: false,
},
@ -587,6 +594,11 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
machineKeys[3]: false,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
@ -641,13 +653,10 @@ func TestFailoverRoute(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err)
notif := notifier.NewNotifier()
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
notif,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
@ -655,23 +664,15 @@ func TestFailoverRoute(t *testing.T) {
)
assert.NoError(t, err)
// Pretend that all the nodes are connected to control
for idx, key := range machineKeys {
// Pretend one node is offline
if idx == 3 {
continue
}
notif.AddNode(key, sink)
}
for _, route := range tt.routes {
if err := db.db.Save(&route).Error; err != nil {
if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err)
}
}
got, err := db.failoverRoute(&tt.failingRoute)
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
})
if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
@ -685,3 +686,231 @@ func TestFailoverRoute(t *testing.T) {
})
}
}
// func TestDisableRouteFailover(t *testing.T) {
// machineKeys := []key.MachinePublic{
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// }
// tests := []struct {
// name string
// nodes types.Nodes
// routeID uint64
// isConnected map[key.MachinePublic]bool
// wantMachineKey key.MachinePublic
// wantErr string
// }{
// {
// name: "single-route",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// Node: types.Node{
// MachineKey: machineKeys[0],
// },
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[0],
// },
// {
// name: "failover-simple",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "no-failover-offline",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: false,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "failover-to-online",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: true,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
// assert.NoError(t, err)
// // bootstrap db
// datab.DB.Transaction(func(tx *gorm.DB) error {
// for _, node := range tt.nodes {
// err := tx.Save(node).Error
// if err != nil {
// return err
// }
// _, err = SaveNodeRoutes(tx, node)
// if err != nil {
// return err
// }
// }
// return nil
// })
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
// return DisableRoute(tx, tt.routeID, tt.isConnected)
// })
// // if (err.Error() != "") != tt.wantErr {
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
// // return
// // }
// if len(got.ChangeNodes) != 1 {
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
// }
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }

View file

@ -6,7 +6,6 @@ import (
"os"
"testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"gopkg.in/check.v1"
)
@ -48,7 +47,6 @@ func (s *Suite) ResetDB(c *check.C) {
"sqlite3",
tmpDir+"/headscale_test.db",
false,
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},

View file

@ -15,22 +15,25 @@ var (
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
)
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
return CreateUser(tx, name)
})
}
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user := types.User{}
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists
}
user.Name = name
if err := hsdb.db.Create(&user).Error; err != nil {
if err := tx.Create(&user).Error; err != nil {
log.Error().
Str("func", "CreateUser").
Err(err).
@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return &user, nil
}
func (hsdb *HSDatabase) DestroyUser(name string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DestroyUser(tx, name)
})
}
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it.
func (hsdb *HSDatabase) DestroyUser(name string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
user, err := hsdb.getUser(name)
func DestroyUser(tx *gorm.DB, name string) error {
user, err := GetUser(tx, name)
if err != nil {
return ErrUserNotFound
}
nodes, err := hsdb.listNodesByUser(name)
nodes, err := ListNodesByUser(tx, name)
if err != nil {
return err
}
@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
return ErrUserStillHasNodes
}
keys, err := hsdb.listPreAuthKeys(name)
keys, err := ListPreAuthKeys(tx, name)
if err != nil {
return err
}
for _, key := range keys {
err = hsdb.destroyPreAuthKey(key)
err = DestroyPreAuthKey(tx, key)
if err != nil {
return err
}
}
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
if result := tx.Unscoped().Delete(&user); result.Error != nil {
return result.Error
}
return nil
}
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return RenameUser(tx, oldName, newName)
})
}
// RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name.
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func RenameUser(tx *gorm.DB, oldName, newName string) error {
var err error
oldUser, err := hsdb.getUser(oldName)
oldUser, err := GetUser(tx, oldName)
if err != nil {
return err
}
@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
if err != nil {
return err
}
_, err = hsdb.getUser(newName)
_, err = GetUser(tx, newName)
if err == nil {
return ErrUserExists
}
@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
oldUser.Name = newName
if result := hsdb.db.Save(&oldUser); result.Error != nil {
if result := tx.Save(&oldUser); result.Error != nil {
return result.Error
}
return nil
}
// GetUser fetches a user by name.
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getUser(name)
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUser(rx, name)
})
}
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{}
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
return &user, nil
}
// ListUsers gets all the existing users.
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listUsers()
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx)
})
}
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
// ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB) ([]types.User, error) {
users := []types.User{}
if err := hsdb.db.Find(&users).Error; err != nil {
if err := tx.Find(&users).Error; err != nil {
return nil, err
}
@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
}
// ListNodesByUser gets all the nodes in a given user.
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByUser(name)
}
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user, err := hsdb.getUser(name)
user, err := GetUser(tx, name)
if err != nil {
return nil, err
}
nodes := types.Nodes{}
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
}
// AssignNodeToUser assigns a Node to a user.
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, node, username)
})
}
// AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
err := util.CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := hsdb.getUser(username)
user, err := GetUser(tx, username)
if err != nil {
return err
}
node.User = *user
if result := hsdb.db.Save(&node); result.Error != nil {
if result := tx.Save(&node); result.Error != nil {
return result.Error
}

View file

@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes)
@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, newUser.Name)

View file

@ -211,7 +211,7 @@ func DERPProbeHandler(
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL
// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL.
func DERPBootstrapDNSHandler(
derpMap *tailcfg.DERPMap,
) func(http.ResponseWriter, *http.Request) {

View file

@ -8,11 +8,13 @@ import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
if err != nil {
return nil, err
return err
}
err = api.h.db.ExpirePreAuthKey(preAuthKey)
return db.ExpirePreAuthKey(tx, preAuthKey)
})
if err != nil {
return nil, err
}
@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
node, err := api.h.db.RegisterNodeFromAuthCallback(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
return db.RegisterNodeFromAuthCallback(
tx,
api.h.registrationCache,
mkey,
request.GetUser(),
nil,
util.RegisterMethodCLI,
api.h.cfg.IPPrefixes,
)
})
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RegisterNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
ctx context.Context,
request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
return nil, err
}
}
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
if err != nil {
return nil, err
}
for _, tag := range request.GetTags() {
err := validateTag(tag)
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
err = api.h.db.SetTags(node, request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.Internal, err.Error())
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.SetTags",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
err = api.h.db.DeleteNode(
node,
api.h.nodeNotifier.ConnectedMap(),
)
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return &v1.DeleteNodeResponse{}, nil
}
@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
ctx context.Context,
request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
now := time.Now()
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry(
tx,
request.GetNodeId(),
now,
)
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return nil, err
}
now := time.Now()
selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
api.h.db.NodeSetExpiry(
node,
now,
)
stateUpdate := types.StateUpdateExpire(node.ID, now)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
Str("node", node.Hostname).
@ -306,19 +365,32 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context,
request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
if err != nil {
return nil, err
}
err = api.h.db.RenameNode(
node,
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode(
tx,
request.GetNodeId(),
request.GetNewName(),
)
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RenameNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
Str("node", node.Hostname).
Str("new_name", request.GetNewName()).
@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context,
request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap()
if request.GetUser() != "" {
nodes, err := api.h.db.ListNodesByUser(request.GetUser())
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, request.GetUser())
})
if err != nil {
return nil, err
}
@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
resp.Online = isConnected[node.MachineKey]
response[index] = resp
}
@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
resp.Online = isConnected[node.MachineKey]
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
&node,
node,
)
resp.InvalidTags = invalidTags
resp.ValidTags = validTags
@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
ctx context.Context,
request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) {
routes, err := api.h.db.GetRoutes()
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
return db.GetRoutes(rx)
})
if err != nil {
return nil, err
}
@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
ctx context.Context,
request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) {
err := api.h.db.EnableRoute(request.GetRouteId())
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnableRoute(tx, request.GetRouteId())
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
api.h.nodeNotifier.NotifyAll(
ctx, *update)
}
return &v1.EnableRouteResponse{}, nil
}
@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
ctx context.Context,
request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) {
err := api.h.db.DisableRoute(request.GetRouteId())
isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, *update)
}
return &v1.DisableRouteResponse{}, nil
}
@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context,
request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) {
err := api.h.db.DeleteRoute(request.GetRouteId())
isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
}
return &v1.DeleteRouteResponse{}, nil
}

View file

@ -272,6 +272,7 @@ func (m *Mapper) LiteMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
if err != nil {
@ -290,7 +291,7 @@ func (m *Mapper) LiteMapResponse(
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.SSHPolicy = sshPolicy
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
@ -392,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse(
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
patches := append(patches, p)
m.patches[uint64(change.NodeID)] = patches
m.patches[uint64(change.NodeID)] = append(patches, p)
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
@ -470,6 +469,8 @@ func (m *Mapper) marshalMapResponse(
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case isSelfUpdate(messages...):
responseType = "self"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
responseType = "lite"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
@ -668,3 +669,13 @@ func appendPeerChanges(
return nil
}
func isSelfUpdate(messages ...string) bool {
for _, message := range messages {
if strings.Contains(message, types.SelfUpdateIdentifier) {
return true
}
}
return false
}

View file

@ -72,7 +72,7 @@ func tailNode(
}
var derp string
if node.Hostinfo.NetInfo != nil {
if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
} else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.

View file

@ -1,6 +1,7 @@
package notifier
import (
"context"
"fmt"
"strings"
"sync"
@ -14,24 +15,28 @@ import (
type Notifier struct {
l sync.RWMutex
nodes map[string]chan<- types.StateUpdate
connected map[key.MachinePublic]bool
}
func NewNotifier() *Notifier {
return &Notifier{}
return &Notifier{
nodes: make(map[string]chan<- types.StateUpdate),
connected: make(map[key.MachinePublic]bool),
}
}
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to add node")
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
n.nodes = make(map[string]chan<- types.StateUpdate)
}
n.nodes[machineKey.String()] = c
n.connected[machineKey] = true
log.Trace().
Str("machine_key", machineKey.ShortString()).
@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to remove node")
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
if len(n.nodes) == 0 {
return
}
delete(n.nodes, machineKey.String())
n.connected[machineKey] = false
log.Trace().
Str("machine_key", machineKey.ShortString()).
@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
n.l.RLock()
defer n.l.RUnlock()
if _, ok := n.nodes[machineKey.String()]; ok {
return true
return n.connected[machineKey]
}
return false
// TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
return n.connected
}
func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update)
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
}
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignore ...string,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifing")
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
continue
}
log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update")
c <- update
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) {
func (n *Notifier) NotifyByMachineKey(
ctx context.Context,
update types.StateUpdate,
mKey key.MachinePublic,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifing")
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if c, ok := n.nodes[mKey.String()]; ok {
c <- update
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}

View file

@ -20,6 +20,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key"
)
@ -492,7 +493,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("node", node.Hostname).
Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node, expiry)
err := h.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "Failed to refresh node")
http.Error(
@ -536,6 +537,12 @@ func (h *Headscale) validateNodeForOIDCCallback(
util.LogErr(err, "Failed to write response")
}
stateUpdate := types.StateUpdateExpire(node.ID, expiry)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return nil, true, nil
}
@ -613,14 +620,22 @@ func (h *Headscale) registerNodeForOIDCCallback(
machineKey *key.MachinePublic,
expiry time.Time,
) error {
if _, err := h.db.RegisterNodeFromAuthCallback(
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache,
*machineKey,
user.Name,
&expiry,
util.RegisterMethodOIDC,
h.cfg.IPPrefixes,
); err != nil {
return err
}
return nil
}); err != nil {
util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)

View file

@ -905,8 +905,14 @@ func (pol *ACLPolicy) TagsOfNode(
validTags := make([]string, 0)
invalidTags := make([]string, 0)
// TODO(kradalby): Why is this sometimes nil? coming from tailNode?
if node == nil {
return validTags, invalidTags
}
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
if node.Hostinfo != nil {
for _, tag := range node.Hostinfo.RequestTags {
owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) {
@ -932,6 +938,7 @@ func (pol *ACLPolicy) TagsOfNode(
for tag := range validTagMap {
validTags = append(validTags, tag)
}
}
return validTags, invalidTags
}

View file

@ -4,12 +4,15 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
@ -128,10 +131,14 @@ func (h *Headscale) handlePoll(
if h.ACLPolicy != nil {
// update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil {
logErr(err, "Error running auto approved routes")
}
if update != nil {
sendUpdate = true
}
}
}
@ -146,7 +153,7 @@ func (h *Headscale) handlePoll(
}
if sendUpdate {
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
@ -161,7 +168,9 @@ func (h *Headscale) handlePoll(
Message: "called from handlePoll -> update -> new hostinfo",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
@ -174,7 +183,9 @@ func (h *Headscale) handlePoll(
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
@ -183,7 +194,7 @@ func (h *Headscale) handlePoll(
}
}
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
@ -195,7 +206,9 @@ func (h *Headscale) handlePoll(
ChangePatches: []*tailcfg.PeerChange{&change},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
@ -251,7 +264,7 @@ func (h *Headscale) handlePoll(
}
}
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
@ -288,7 +301,10 @@ func (h *Headscale) handlePoll(
// update ACLRules with peer informations (to update server tags if necessary)
if h.ACLPolicy != nil {
// update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
// This state update is ignored as it will be sent
// as part of the whole node
// TODO(kradalby): figure out if that is actually correct
_, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil {
logErr(err, "Error running auto approved routes")
}
@ -324,11 +340,17 @@ func (h *Headscale) handlePoll(
Message: "called from handlePoll -> new node added",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
if len(node.Routes) > 0 {
go h.pollFailoverRoutes(logErr, "new node", node)
}
// Set up the client stream
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
@ -346,15 +368,9 @@ func (h *Headscale) handlePoll(
keepAliveTicker := time.NewTicker(keepAliveInterval)
ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname)
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))
defer cancel()
if len(node.Routes) > 0 {
go h.db.EnsureFailoverRouteIsAvailable(node)
}
for {
logInfo("Waiting for update on stream channel")
select {
@ -403,6 +419,7 @@ func (h *Headscale) handlePoll(
return
}
startMapResp := time.Now()
switch update.Type {
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
@ -411,6 +428,7 @@ func (h *Headscale) handlePoll(
case types.StatePeerChanged:
logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
isConnectedMap := h.nodeNotifier.ConnectedMap()
for _, node := range update.ChangeNodes {
// If a node is not reported to be online, it might be
// because the value is outdated, check with the notifier.
@ -418,7 +436,7 @@ func (h *Headscale) handlePoll(
// this might be because it has announced itself, but not
// reached the stage to actually create the notifier channel.
if node.IsOnline != nil && !*node.IsOnline {
isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
isOnline := isConnectedMap[node.MachineKey]
node.IsOnline = &isOnline
}
}
@ -434,7 +452,7 @@ func (h *Headscale) handlePoll(
if len(update.ChangeNodes) == 1 {
logInfo("Sending SelfUpdate MapResponse")
node = update.ChangeNodes[0]
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy)
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier)
} else {
logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.")
}
@ -449,8 +467,11 @@ func (h *Headscale) handlePoll(
return
}
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response")
// Only send update if there is change
if data != nil {
startWrite := time.Now()
_, err = writer.Write(data)
if err != nil {
logErr(err, "Could not write the map response")
@ -468,6 +489,7 @@ func (h *Headscale) handlePoll(
return
}
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node")
log.Info().
Caller().
@ -487,7 +509,7 @@ func (h *Headscale) handlePoll(
go h.updateNodeOnlineStatus(false, node)
// Failover the node's routes if any.
go h.db.FailoverNodeRoutesWithNotify(node)
go h.pollFailoverRoutes(logErr, "node closing connection", node)
// The connection has been closed, so we can stop polling.
return
@ -500,6 +522,22 @@ func (h *Headscale) handlePoll(
}
}
func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) {
update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node)
})
if err != nil {
logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
return
}
if update != nil && !update.Empty() && update.Valid() {
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String())
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@ -519,10 +557,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
},
}
if statusUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
}
err := h.db.UpdateLastSeen(node)
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
})
if err != nil {
log.Error().Err(err).Msg("Cannot update node LastSeen")

View file

@ -13,7 +13,7 @@ import (
)
const (
MinimumCapVersion tailcfg.CapabilityVersion = 56
MinimumCapVersion tailcfg.CapabilityVersion = 58
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol

View file

@ -1,15 +1,19 @@
package types
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
"time"
"tailscale.com/tailcfg"
)
const SelfUpdateIdentifier = "self-update"
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
type IPPrefix netip.Prefix
@ -160,3 +164,37 @@ func (su *StateUpdate) Valid() bool {
return true
}
// Empty reports if there are any updates in the StateUpdate.
func (su *StateUpdate) Empty() bool {
switch su.Type {
case StatePeerChanged:
return len(su.ChangeNodes) == 0
case StatePeerChangedPatch:
return len(su.ChangePatches) == 0
case StatePeerRemoved:
return len(su.Removed) == 0
}
return false
}
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(nodeID),
KeyExpiry: &expiry,
},
},
}
}
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(
context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
3*time.Second,
)
return ctx2
}

View file

@ -2,7 +2,6 @@ package types
import (
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
@ -25,9 +24,10 @@ func (n *User) TailscaleUser() *tailcfg.User {
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "",
Logins: []tailcfg.LoginID{},
Created: time.Time{},
Created: n.CreatedAt,
}
return &user
@ -38,6 +38,7 @@ func (n *User) TailscaleLogin() *tailcfg.Login {
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "",
}

View file

@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAll[4].GetGivenName(), "node-5")
for idx := 0; idx < 3; idx++ {
_, err := headscale.Execute(
res, err := headscale.Execute(
[]string{
"headscale",
"nodes",
@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) {
},
)
assert.Nil(t, err)
assert.Contains(t, res, "Node renamed")
}
var listAllAfterRename []v1.Node

View file

@ -26,6 +26,8 @@ func TestPingAllByIP(t *testing.T) {
assertNoErr(t, err)
defer scenario.Shutdown()
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
@ -321,7 +323,12 @@ func TestTaildrop(t *testing.T) {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
}
}
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
curlCommand := []string{
"curl",
"--unix-socket",
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error {
result, _, err := client.Execute(curlCommand)
if err != nil {
@ -338,13 +345,23 @@ func TestTaildrop(t *testing.T) {
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
len(fts),
len(allClients)-1,
ftStr,
)
}
return err
})
if err != nil {
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
}
@ -434,72 +451,6 @@ func TestTaildrop(t *testing.T) {
}
}
func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"magicdns1": len(MustTestVersions),
"magicdns2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
// Poor mans cache
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
_, err = scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
for _, client := range allClients {
for _, peer := range allClients {
// It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN()
command := []string{
"tailscale",
"ip", peerFQDN,
}
result, _, err := client.Execute(command)
if err != nil {
t.Fatalf(
"failed to execute resolve/ip command %s from %s: %s",
peerFQDN,
client.Hostname(),
err,
)
}
ips, err := peer.IPs()
if err != nil {
t.Fatalf(
"failed to get ips for %s: %s",
peer.Hostname(),
err,
)
}
for _, ip := range ips {
if !strings.Contains(result, ip.String()) {
t.Fatalf("ip %s is not found in \n%s\n", ip.String(), result)
}
}
}
}
}
func TestExpireNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
@ -545,7 +496,7 @@ func TestExpireNode(t *testing.T) {
// TODO(kradalby): This is Headscale specific and would not play nicely
// with other implementations of the ControlServer interface
result, err := headscale.Execute([]string{
"headscale", "nodes", "expire", "--identifier", "0", "--output", "json",
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
})
assertNoErr(t, err)
@ -576,16 +527,38 @@ func TestExpireNode(t *testing.T) {
assertNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
t.Logf(
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
if peerStatus.KeyExpiry != nil {
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
assert.Truef(
t,
peerStatus.KeyExpiry.Before(now),
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
}
assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
assert.Truef(
t,
peerStatus.Expired,
"node %q should be expired, expired is %v",
peerStatus.HostName,
peerStatus.Expired,
)
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
if !strings.Contains(stderr, "node key has expired") {
t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname())
t.Errorf(
"expected to be unable to ping expired host %q from %q",
node.GetName(),
client.Hostname(),
)
}
} else {
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
@ -597,7 +570,7 @@ func TestExpireNode(t *testing.T) {
// NeedsLogin means that the node has understood that it is no longer
// valid.
assert.Equal(t, "NeedsLogin", status.BackendState)
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
}
}
}
@ -690,7 +663,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
assert.Truef(
t,
lastSeen.After(lastSeenThreshold),
"lastSeen (%v) was not %s after the threshold (%v)",
"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
node.GetName(),
lastSeen,
keepAliveInterval,
lastSeenThreshold,

View file

@ -88,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, routes, 3)
for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes has been sent to the client,
@ -135,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary())
}
time.Sleep(5 * time.Second)
@ -191,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) {
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var disablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
@ -209,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) {
assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
} else {
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary())
}
}
time.Sleep(5 * time.Second)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
@ -294,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients {
for _, client := range allClients[:2] {
status, err := client.Status()
assertNoErr(t, err)
@ -306,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
} else {
t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
}
}
@ -328,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, routes, 2)
t.Logf("initial routes %#v", routes)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
@ -644,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2)
t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
// Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())

View file

@ -56,8 +56,8 @@ var (
"1.44": true, // CapVer: 63
"1.42": true, // CapVer: 61
"1.40": true, // CapVer: 61
"1.38": true, // CapVer: 58
"1.36": true, // Oldest supported version, CapVer: 56
"1.38": true, // Oldest supported version, CapVer: 58
"1.36": false, // CapVer: 56
"1.34": false, // CapVer: 51
"1.32": false, // CapVer: 46
"1.30": false,