Split code into modules

This is a massive commit that restructures the code into modules:

db/
    All functions related to modifying the Database

types/
    All type definitions and methods that can be exclusivly used on
    these types without dependencies

policy/
    All Policy related code, now without dependencies on the Database.

policy/matcher/
    Dedicated code to match machines in a list of FilterRules

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-21 19:37:59 +03:00 committed by Kristoffer Dalby
parent 14e29a7bee
commit feb15365b5
51 changed files with 4677 additions and 4290 deletions

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/types"
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData {
continue continue
} }
if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 { if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 {
isPrimaryStr = "-" isPrimaryStr = "-"
} else { } else {
isPrimaryStr = strconv.FormatBool(route.IsPrimary) isPrimaryStr = strconv.FormatBool(route.IsPrimary)

View file

@ -10,6 +10,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
if cfg.ACL.PolicyPath != "" { if cfg.ACL.PolicyPath != "" {
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
err = app.LoadACLPolicyFromPath(aclPath) pol, err := policy.LoadACLPolicyFromPath(aclPath)
if err != nil { if err != nil {
log.Fatal(). log.Fatal().
Str("path", aclPath). Str("path", aclPath).
Err(err). Err(err).
Msg("Could not load the ACL policy") Msg("Could not load the ACL policy")
} }
app.ACLPolicy = pol
} }
return app, nil return app, nil

View file

@ -18,9 +18,6 @@ const (
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5 registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
) )
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
@ -56,7 +53,7 @@ func (h *Headscale) HealthHandler(
} }
} }
if err := h.db.pingDB(req.Context()); err != nil { if err := h.db.PingDB(req.Context()); err != nil {
respond(err) respond(err)
return return

View file

@ -3,6 +3,7 @@ package hscontrol
import ( import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -10,13 +11,13 @@ import (
func (h *Headscale) generateMapResponse( func (h *Headscale) generateMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *types.Machine,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
log.Trace(). log.Trace().
Str("func", "generateMapResponse"). Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname). Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse(
return nil, err return nil, err
} }
peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) peers, err := h.db.GetValidPeers(h.aclRules, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse(
return nil, err return nil, err
} }
profiles := h.db.getMapResponseUserProfiles(*machine, peers) profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -23,6 +23,9 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale" "github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog" zerolog "github.com/philip-bui/grpc-zerolog"
@ -73,7 +76,7 @@ const (
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg *Config cfg *Config
db *HSDatabase db *db.HSDatabase
dbString string dbString string
dbType string dbType string
dbDebug bool dbDebug bool
@ -83,7 +86,7 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *DERPServer DERPServer *DERPServer
aclPolicy *ACLPolicy ACLPolicy *policy.ACLPolicy
aclRules []tailcfg.FilterRule aclRules []tailcfg.FilterRule
sshPolicy *tailcfg.SSHPolicy sshPolicy *tailcfg.SSHPolicy
@ -99,6 +102,12 @@ type Headscale struct {
stateUpdateChan chan struct{} stateUpdateChan chan struct{}
cancelStateUpdateChan chan struct{} cancelStateUpdateChan chan struct{}
// TODO(kradalby): Temporary measure to make sure we can update policy
// across modules, will be removed when aclRules are no longer stored
// globally but generated per node basis.
policyUpdateChan chan struct{}
cancelPolicyUpdateChan chan struct{}
} }
func NewHeadscale(cfg *Config) (*Headscale, error) { func NewHeadscale(cfg *Config) (*Headscale, error) {
@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
var dbString string var dbString string
switch cfg.DBtype { switch cfg.DBtype {
case Postgres: case db.Postgres:
dbString = fmt.Sprintf( dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s", "host=%s dbname=%s user=%s",
cfg.DBhost, cfg.DBhost,
@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
if cfg.DBpass != "" { if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass) dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
} }
case Sqlite: case db.Sqlite:
dbString = cfg.DBpath dbString = cfg.DBpath
default: default:
return nil, errUnsupportedDatabase return nil, errUnsupportedDatabase
@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
stateUpdateChan: make(chan struct{}), stateUpdateChan: make(chan struct{}),
cancelStateUpdateChan: make(chan struct{}), cancelStateUpdateChan: make(chan struct{}),
policyUpdateChan: make(chan struct{}),
cancelPolicyUpdateChan: make(chan struct{}),
} }
go app.watchStateChannel() go app.watchStateChannel()
go app.watchPolicyChannel()
db, err := NewHeadscaleDatabase( database, err := db.NewHeadscaleDatabase(
cfg.DBtype, cfg.DBtype,
dbString, dbString,
cfg.OIDC.StripEmaildomain, cfg.OIDC.StripEmaildomain,
app.dbDebug, app.dbDebug,
app.stateUpdateChan, app.stateUpdateChan,
app.policyUpdateChan,
cfg.IPPrefixes, cfg.IPPrefixes,
cfg.BaseDomain) cfg.BaseDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
app.db = db app.db = database
if cfg.OIDC.Issuer != "" { if cfg.OIDC.Issuer != "" {
err = app.initOIDC() err = app.initOIDC()
@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C { for range ticker.C {
h.expireEphemeralNodesWorker() h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout)
} }
} }
@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
func (h *Headscale) expireExpiredMachines(milliSeconds int64) { func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C { for range ticker.C {
h.expireExpiredMachinesWorker() h.db.ExpireExpiredMachines(h.getLastStateChange())
} }
} }
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C { for range ticker.C {
err := h.db.handlePrimarySubnetFailover() err := h.db.HandlePrimarySubnetFailover()
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to handle primary subnet failover") log.Error().Err(err).Msg("failed to handle primary subnet failover")
} }
} }
} }
func (h *Headscale) expireEphemeralNodesWorker() {
users, err := h.db.ListUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
}
for _, user := range users {
machines, err := h.db.ListMachinesByUser(user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing machines in user")
return
}
expiredFound := false
for _, machine := range machines {
if machine.isEphemeral() && machine.LastSeen != nil &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
expiredFound = true
log.Info().
Str("machine", machine.Hostname).
Msg("Ephemeral client removed from database")
err = h.db.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
if expiredFound {
h.setLastStateChangeToNow()
}
}
}
func (h *Headscale) expireExpiredMachinesWorker() {
users, err := h.db.ListUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
}
for _, user := range users {
machines, err := h.db.ListMachinesByUser(user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing machines in user")
return
}
expiredFound := false
for index, machine := range machines {
if machine.isExpired() &&
machine.Expiry.After(h.getLastStateChange(user)) {
expiredFound = true
err := h.db.ExpireMachine(&machines[index])
if err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Str("name", machine.GivenName).
Msg("🤮 Cannot expire machine")
} else {
log.Info().
Str("machine", machine.Hostname).
Str("name", machine.GivenName).
Msg("Machine successfully expired")
}
}
}
if expiredFound {
h.setLastStateChangeToNow()
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{}, req interface{},
info *grpc.UnaryServerInfo, info *grpc.UnaryServerInfo,
@ -565,6 +487,8 @@ func (h *Headscale) Serve() error {
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
} }
// TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown.
go h.expireEphemeralNodes(updateInterval) go h.expireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval) go h.expireExpiredMachines(updateInterval)
@ -774,10 +698,12 @@ func (h *Headscale) Serve() error {
if h.cfg.ACL.PolicyPath != "" { if h.cfg.ACL.PolicyPath != "" {
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
err := h.LoadACLPolicyFromPath(aclPath) pol, err := policy.LoadACLPolicyFromPath(aclPath)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to reload ACL policy") log.Error().Err(err).Msg("Failed to reload ACL policy")
} }
h.ACLPolicy = pol
log.Info(). log.Info().
Str("path", aclPath). Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
@ -824,12 +750,12 @@ func (h *Headscale) Serve() error {
close(h.stateUpdateChan) close(h.stateUpdateChan)
close(h.cancelStateUpdateChan) close(h.cancelStateUpdateChan)
<-h.cancelPolicyUpdateChan
close(h.policyUpdateChan)
close(h.cancelPolicyUpdateChan)
// Close db connections // Close db connections
db, err := h.db.db.DB() err = h.db.Close()
if err != nil {
log.Error().Err(err).Msg("Failed to get db handle")
}
err = db.Close()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to close db") log.Error().Err(err).Msg("Failed to close db")
} }
@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() {
} }
} }
// TODO(kradalby): baby steps, make this more robust.
func (h *Headscale) watchPolicyChannel() {
for {
select {
case <-h.policyUpdateChan:
machines, err := h.db.ListMachines()
if err != nil {
log.Error().Err(err).Msg("failed to fetch machines during policy update")
}
rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().Err(err).Msg("failed to update ACL rules")
}
h.aclRules = rules
h.sshPolicy = sshPolicy
case <-h.cancelPolicyUpdateChan:
return
}
}
}
func (h *Headscale) setLastStateChangeToNow() { func (h *Headscale) setLastStateChangeToNow() {
var err error var err error
@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() {
} }
} }
func (h *Headscale) getLastStateChange(users ...User) time.Time { func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
times := []time.Time{} times := []time.Time{}
// getLastStateChange takes a list of users as a "filter", if no users // getLastStateChange takes a list of users as a "filter", if no users

480
hscontrol/db/acls_test.go Normal file
View file

@ -0,0 +1,480 @@
package db
import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
)
// TODO(kradalby):
// Convert these tests to being non-database dependent and table driven. They are
// very verbose, and dont really need the database.
func (s *Suite) TestSshRules(c *check.C) {
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
aclPolicy := &policy.ACLPolicy{
Groups: policy.Groups{
"group:test": []string{"user1"},
},
Hosts: policy.Hosts{
"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
SSHs: []policy.SSH{
{
Action: "accept",
Sources: []string{"group:test"},
Destinations: []string{"client"},
Users: []string{"autogroup:nonroot"},
},
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"client"},
Users: []string{"autogroup:nonroot"},
},
},
}
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(sshPolicy, check.NotNil)
c.Assert(sshPolicy.Rules, check.HasLen, 2)
c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1)
c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")
c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1)
c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*")
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Sources section.
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"tag:test"},
Destinations: []string{"*:*"},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
}
// need a test with:
// tag on a host that isn't owned by a tag owners. So the user
// of the host should be valid.
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:foo"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"*:*"},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "webserver")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "user")
hostInfo2 := tailcfg.Hostinfo{
OS: "debian",
Hostname: "Hostname",
}
c.Assert(err, check.NotNil)
machine = types.Machine{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"tag:webapp:80,443"},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
c.Assert(rules[0].DstPorts, check.HasLen, 2)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
}
func (s *Suite) TestPortUser(c *check.C) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
machine := types.Machine{
ID: 0,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips,
AuthKeyID: uint(pak.ID),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(`
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
}
func (s *Suite) TestPortGroup(c *check.C) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips,
AuthKeyID: uint(pak.ID),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(`
{
"groups": {
"group:example": [
"testuser",
],
},
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
}

View file

@ -3,21 +3,22 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package hscontrol package db
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
) )
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) {
var ips MachineAddresses var ips types.MachineAddresses
var err error var err error
for _, ipPrefix := range hsdb.ipPrefixes { for _, ipPrefix := range hsdb.ipPrefixes {
var ip *netip.Addr var ip *netip.Addr
@ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
// but this was quick to get running and it should be enough // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet. // to begin experimenting with a dual stack tailnet.
var addressesSlices []string var addressesSlices []string
hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, slice := range addressesSlices { for _, slice := range addressesSlices {
var machineAddresses MachineAddresses var machineAddresses types.MachineAddresses
err := machineAddresses.Scan(slice) err := machineAddresses.Scan(slice)
if err != nil { if err != nil {
return &netipx.IPSet{}, fmt.Errorf( return &netipx.IPSet{}, fmt.Errorf(

View file

@ -1,14 +1,16 @@
package hscontrol package db
import ( import (
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
func (s *Suite) TestGetAvailableIp(c *check.C) { func (s *Suite) TestGetAvailableIp(c *check.C) {
ips, err := app.db.getAvailableIPs() ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
} }
func (s *Suite) TestGetUsedIps(c *check.C) { func (s *Suite) TestGetUsedIps(c *check.C) {
ips, err := app.db.getAvailableIPs() ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
user, err := app.db.CreateUser("test-ip") user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "testmachine") _, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
} }
app.db.db.Save(&machine) db.db.Save(&machine)
usedIps, err := app.db.getUsedIPs() usedIps, err := db.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
c.Assert(usedIps.Contains(expected), check.Equals, true) c.Assert(usedIps.Contains(expected), check.Equals, true)
machine1, err := app.db.GetMachineByID(0) machine1, err := db.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(machine1.IPAddresses[0], check.Equals, expected) c.Assert(machine1.IPAddresses[0], check.Equals, expected)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetMultiIp(c *check.C) { func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := app.db.CreateUser("test-ip-multi") user, err := db.CreateUser("test-ip-multi")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
for index := 1; index <= 350; index++ { for index := 1; index <= 350; index++ {
app.db.ipAllocationMutex.Lock() db.ipAllocationMutex.Lock()
ips, err := app.db.getAvailableIPs() ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "testmachine") _, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
machine := Machine{ machine := types.Machine{
ID: uint64(index), ID: uint64(index),
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
} }
app.db.db.Save(&machine) db.db.Save(&machine)
app.db.ipAllocationMutex.Unlock() db.ipAllocationMutex.Unlock()
} }
usedIps, err := app.db.getUsedIPs() usedIps, err := db.getUsedIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1") expected0 := netip.MustParseAddr("10.27.0.1")
@ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(usedIps.Contains(expected300), check.Equals, true) c.Assert(usedIps.Contains(expected300), check.Equals, true)
// Check that we can read back the IPs // Check that we can read back the IPs
machine1, err := app.db.GetMachineByID(1) machine1, err := db.GetMachineByID(1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert( c.Assert(
@ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
netip.MustParseAddr("10.27.0.1"), netip.MustParseAddr("10.27.0.1"),
) )
machine50, err := app.db.GetMachineByID(50) machine50, err := db.GetMachineByID(50)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert(len(machine50.IPAddresses), check.Equals, 1)
c.Assert( c.Assert(
@ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
) )
expectedNextIP := netip.MustParseAddr("10.27.1.95") expectedNextIP := netip.MustParseAddr("10.27.1.95")
nextIP, err := app.db.getAvailableIPs() nextIP, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(nextIP), check.Equals, 1) c.Assert(len(nextIP), check.Equals, 1)
@ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
// If we call get Available again, we should receive // If we call get Available again, we should receive
// the same IP, as it has not been reserved. // the same IP, as it has not been reserved.
nextIP2, err := app.db.getAvailableIPs() nextIP2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(nextIP2), check.Equals, 1) c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ips, err := app.db.getAvailableIPs() ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1") expected := netip.MustParseAddr("10.27.0.1")
@ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String()) c.Assert(ips[0].String(), check.Equals, expected.String())
user, err := app.db.CreateUser("test-ip") user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "testmachine") _, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
ips2, err := app.db.getAvailableIPs() ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(ips2), check.Equals, 1) c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String()) c.Assert(ips2[0].String(), check.Equals, expected.String())
c.Assert(channelUpdates, check.Equals, int32(0))
} }

View file

@ -1,4 +1,4 @@
package hscontrol package db
import ( import (
"errors" "errors"
@ -6,10 +6,9 @@ import (
"strings" "strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/types/known/timestamppb"
) )
const ( const (
@ -19,22 +18,10 @@ const (
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
// APIKey describes the datamodel for API keys used to remotely authenticate with
// headscale.
type APIKey struct {
ID uint64 `gorm:"primary_key"`
Prefix string `gorm:"uniqueIndex"`
Hash []byte
CreatedAt *time.Time
Expiration *time.Time
LastSeen *time.Time
}
// CreateAPIKey creates a new ApiKey in a user, and returns it. // CreateAPIKey creates a new ApiKey in a user, and returns it.
func (hsdb *HSDatabase) CreateAPIKey( func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time, expiration *time.Time,
) (string, *APIKey, error) { ) (string, *types.APIKey, error) {
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
return "", nil, err return "", nil, err
} }
key := APIKey{ key := types.APIKey{
Prefix: prefix, Prefix: prefix,
Hash: hash, Hash: hash,
Expiration: expiration, Expiration: expiration,
@ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
} }
// ListAPIKeys returns the list of ApiKeys for a user. // ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
keys := []APIKey{} keys := []types.APIKey{}
if err := hsdb.db.Find(&keys).Error; err != nil { if err := hsdb.db.Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
@ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
} }
// GetAPIKey returns a ApiKey for a given key. // GetAPIKey returns a ApiKey for a given key.
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
key := APIKey{} key := types.APIKey{}
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
} }
// GetAPIKeyByID returns a ApiKey for a given id. // GetAPIKeyByID returns a ApiKey for a given id.
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
key := APIKey{} key := types.APIKey{}
if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
// does not exist. // does not exist.
func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
return result.Error return result.Error
} }
@ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
} }
// ExpireAPIKey marks a ApiKey as expired. // ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
@ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
return true, nil return true, nil
} }
func (key *APIKey) toProto() *v1.ApiKey {
protoKey := v1.ApiKey{
Id: key.ID,
Prefix: key.Prefix,
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
if key.LastSeen != nil {
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
}
return &protoKey
}

View file

@ -1,4 +1,4 @@
package hscontrol package db
import ( import (
"time" "time"
@ -7,7 +7,7 @@ import (
) )
func (*Suite) TestCreateAPIKey(c *check.C) { func (*Suite) TestCreateAPIKey(c *check.C) {
apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) apiKeyStr, apiKey, err := db.CreateAPIKey(nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil) c.Assert(apiKey, check.NotNil)
@ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
c.Assert(apiKey.Hash, check.NotNil) c.Assert(apiKey.Hash, check.NotNil)
c.Assert(apiKeyStr, check.Not(check.Equals), "") c.Assert(apiKeyStr, check.Not(check.Equals), "")
_, err = app.db.ListAPIKeys() _, err = db.ListAPIKeys()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
keys, err := app.db.ListAPIKeys() keys, err := db.ListAPIKeys()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
key, err := app.db.GetAPIKey("does-not-exist") key, err := db.GetAPIKey("does-not-exist")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(key, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestValidateAPIKeyOk(c *check.C) { func (*Suite) TestValidateAPIKeyOk(c *check.C) {
nowPlus2 := time.Now().Add(2 * time.Hour) nowPlus2 := time.Now().Add(2 * time.Hour)
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil) c.Assert(apiKey, check.NotNil)
valid, err := app.db.ValidateAPIKey(apiKeyStr) valid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true) c.Assert(valid, check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil) c.Assert(apiKey, check.NotNil)
valid, err := app.db.ValidateAPIKey(apiKeyStr) valid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, false) c.Assert(valid, check.Equals, false)
now := time.Now() now := time.Now()
apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil) c.Assert(apiKey, check.NotNil)
validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) validNow, err := db.ValidateAPIKey(apiKeyStrNow)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(validNow, check.Equals, false) c.Assert(validNow, check.Equals, false)
validSilly, err := app.db.ValidateAPIKey("nota.validkey") validSilly, err := db.ValidateAPIKey("nota.validkey")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(validSilly, check.Equals, false) c.Assert(validSilly, check.Equals, false)
validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") validWithErr, err := db.ValidateAPIKey("produceerrorkey")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(validWithErr, check.Equals, false) c.Assert(validWithErr, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestExpireAPIKey(c *check.C) { func (*Suite) TestExpireAPIKey(c *check.C) {
nowPlus2 := time.Now().Add(2 * time.Hour) nowPlus2 := time.Now().Add(2 * time.Hour)
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil) c.Assert(apiKey, check.NotNil)
valid, err := app.db.ValidateAPIKey(apiKeyStr) valid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true) c.Assert(valid, check.Equals, true)
err = app.db.ExpireAPIKey(apiKey) err = db.ExpireAPIKey(apiKey)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(apiKey.Expiration, check.NotNil) c.Assert(apiKey.Expiration, check.NotNil)
notValid, err := app.db.ValidateAPIKey(apiKeyStr) notValid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(notValid, check.Equals, false) c.Assert(notValid, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
} }

View file

@ -1,9 +1,7 @@
package hscontrol package db
import ( import (
"context" "context"
"database/sql/driver"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
@ -11,11 +9,12 @@ import (
"time" "time"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"tailscale.com/tailcfg"
) )
const ( const (
@ -26,7 +25,6 @@ const (
var ( var (
errValueNotFound = errors.New("not found") errValueNotFound = errors.New("not found")
ErrCannotParsePrefix = errors.New("cannot parse prefix")
errDatabaseNotSupported = errors.New("database type not supported") errDatabaseNotSupported = errors.New("database type not supported")
) )
@ -40,6 +38,7 @@ type KV struct {
type HSDatabase struct { type HSDatabase struct {
db *gorm.DB db *gorm.DB
notifyStateChan chan<- struct{} notifyStateChan chan<- struct{}
notifyPolicyChan chan<- struct{}
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
@ -54,6 +53,7 @@ func NewHeadscaleDatabase(
dbType, connectionAddr string, dbType, connectionAddr string,
stripEmailDomain, debug bool, stripEmailDomain, debug bool,
notifyStateChan chan<- struct{}, notifyStateChan chan<- struct{},
notifyPolicyChan chan<- struct{},
ipPrefixes []netip.Prefix, ipPrefixes []netip.Prefix,
baseDomain string, baseDomain string,
) (*HSDatabase, error) { ) (*HSDatabase, error) {
@ -65,6 +65,7 @@ func NewHeadscaleDatabase(
db := HSDatabase{ db := HSDatabase{
db: dbConn, db: dbConn,
notifyStateChan: notifyStateChan, notifyStateChan: notifyStateChan,
notifyPolicyChan: notifyPolicyChan,
ipPrefixes: ipPrefixes, ipPrefixes: ipPrefixes,
baseDomain: baseDomain, baseDomain: baseDomain,
@ -79,30 +80,30 @@ func NewHeadscaleDatabase(
_ = dbConn.Migrator().RenameTable("namespaces", "users") _ = dbConn.Migrator().RenameTable("namespaces", "users")
err = dbConn.AutoMigrate(User{}) err = dbConn.AutoMigrate(types.User{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id")
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") _ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname")
// GivenName is used as the primary source of DNS names, make sure // GivenName is used as the primary source of DNS names, make sure
// the field is populated and normalized if it was not when the // the field is populated and normalized if it was not when the
// machine was registered. // machine was registered.
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name")
// If the Machine table has a column for registered, // If the Machine table has a column for registered,
// find all occourences of "false" and drop them. Then // find all occourences of "false" and drop them. Then
// remove the column. // remove the column.
if dbConn.Migrator().HasColumn(&Machine{}, "registered") { if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") {
log.Info(). log.Info().
Msg(`Database has legacy "registered" column in machine, removing...`) Msg(`Database has legacy "registered" column in machine, removing...`)
machines := Machines{} machines := types.Machines{}
if err := dbConn.Not("registered").Find(&machines).Error; err != nil { if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
} }
@ -112,7 +113,7 @@ func NewHeadscaleDatabase(
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Str("machine_key", machine.MachineKey). Str("machine_key", machine.MachineKey).
Msg("Deleting unregistered machine") Msg("Deleting unregistered machine")
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
@ -121,23 +122,23 @@ func NewHeadscaleDatabase(
} }
} }
err := dbConn.Migrator().DropColumn(&Machine{}, "registered") err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered")
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error dropping registered column") log.Error().Err(err).Msg("Error dropping registered column")
} }
} }
err = dbConn.AutoMigrate(&Route{}) err = dbConn.AutoMigrate(&types.Route{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
type MachineAux struct { type MachineAux struct {
ID uint64 ID uint64
EnabledRoutes IPPrefixes EnabledRoutes types.IPPrefixes
} }
machinesAux := []MachineAux{} machinesAux := []MachineAux{}
@ -157,8 +158,8 @@ func NewHeadscaleDatabase(
} }
err = dbConn.Preload("Machine"). err = dbConn.Preload("Machine").
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)).
First(&Route{}). First(&types.Route{}).
Error Error
if err == nil { if err == nil {
log.Info(). log.Info().
@ -168,11 +169,11 @@ func NewHeadscaleDatabase(
continue continue
} }
route := Route{ route := types.Route{
MachineID: machine.ID, MachineID: machine.ID,
Advertised: true, Advertised: true,
Enabled: true, Enabled: true,
Prefix: IPPrefix(prefix), Prefix: types.IPPrefix(prefix),
} }
if err := dbConn.Create(&route).Error; err != nil { if err := dbConn.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route") log.Error().Err(err).Msg("Error creating route")
@ -185,26 +186,26 @@ func NewHeadscaleDatabase(
} }
} }
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes")
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column") log.Error().Err(err).Msg("Error dropping enabled_routes column")
} }
} }
err = dbConn.AutoMigrate(&Machine{}) err = dbConn.AutoMigrate(&types.Machine{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") {
machines := Machines{} machines := types.Machines{}
if err := dbConn.Find(&machines).Error; err != nil { if err := dbConn.Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
} }
for item, machine := range machines { for item, machine := range machines {
if machine.GivenName == "" { if machine.GivenName == "" {
normalizedHostname, err := NormalizeToFQDNRules( normalizedHostname, err := util.NormalizeToFQDNRules(
machine.Hostname, machine.Hostname,
stripEmailDomain, stripEmailDomain,
) )
@ -233,19 +234,19 @@ func NewHeadscaleDatabase(
return nil, err return nil, err
} }
err = dbConn.AutoMigrate(&PreAuthKey{}) err = dbConn.AutoMigrate(&types.PreAuthKey{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = dbConn.Migrator().DropTable("shared_machines") _ = dbConn.Migrator().DropTable("shared_machines")
err = dbConn.AutoMigrate(&APIKey{}) err = dbConn.AutoMigrate(&types.APIKey{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error {
return nil return nil
} }
func (hsdb *HSDatabase) pingDB(ctx context.Context) error { func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second) ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel() defer cancel()
sqlDB, err := hsdb.db.DB() sqlDB, err := hsdb.db.DB()
@ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
return sqlDB.PingContext(ctx) return sqlDB.PingContext(ctx)
} }
// This is a "wrapper" type around tailscales func (hsdb *HSDatabase) Close() error {
// Hostinfo to allow us to add database "serialization" db, err := hsdb.db.DB()
// methods. This allows us to use a typed values throughout
// the code and not have to marshal/unmarshal and error
// check all over the code.
type HostInfo tailcfg.Hostinfo
func (hi *HostInfo) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, hi)
case string:
return json.Unmarshal([]byte(value), hi)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (hi HostInfo) Value() (driver.Value, error) {
bytes, err := json.Marshal(hi)
return string(bytes), err
}
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
prefix, err := netip.ParsePrefix(value)
if err != nil { if err != nil {
return err return err
} }
*i = IPPrefix(prefix)
return nil return db.Close()
default:
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefix) Value() (driver.Value, error) {
prefixStr := netip.Prefix(i).String()
return prefixStr, nil
}
type IPPrefixes []netip.Prefix
func (i *IPPrefixes) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefixes) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}
type StringList []string
func (i *StringList) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i StringList) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
} }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,797 @@
package db
import (
"fmt"
"net/netip"
"regexp"
"strconv"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetMachine(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := &types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(machine)
_, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByID(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachineByID(0)
c.Assert(err, check.NotNil)
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
_, err = db.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
machine := types.Machine{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
_, err = db.GetMachineByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machineKey := key.NewMachine()
machine := types.Machine{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestDeleteMachine(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.db.Save(&machine)
err = db.DeleteMachine(&machine)
c.Assert(err, check.IsNil)
_, err = db.GetMachine(user.Name, "testmachine")
c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestHardDeleteMachine(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.db.Save(&machine)
err = db.HardDeleteMachine(&machine)
c.Assert(err, check.IsNil)
_, err = db.GetMachine(user.Name, "testmachine3")
c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestListPeers(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachineByID(0)
c.Assert(err, check.NotNil)
for index := 0; index <= 10; index++ {
machine := types.Machine{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
Hostname: "testmachine" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
}
machine0ByID, err := db.GetMachineByID(0)
c.Assert(err, check.IsNil)
peersOfMachine0, err := db.ListPeers(machine0ByID)
c.Assert(err, check.IsNil)
c.Assert(len(peersOfMachine0), check.Equals, 9)
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
type base struct {
user *types.User
key *types.PreAuthKey
}
stor := make([]base, 0)
for _, name := range []string{"test", "admin"} {
user, err := db.CreateUser(name)
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
stor = append(stor, base{user, pak})
}
_, err := db.GetMachineByID(0)
c.Assert(err, check.NotNil)
for index := 0; index <= 10; index++ {
machine := types.Machine{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
IPAddresses: types.MachineAddresses{
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
},
Hostname: "testmachine" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID),
}
db.db.Save(&machine)
}
aclPolicy := &policy.ACLPolicy{
Groups: map[string][]string{
"group:test": {"admin"},
},
Hosts: map[string]netip.Prefix{},
TagOwners: map[string][]string{},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"admin"},
Destinations: []string{"*:*"},
},
{
Action: "accept",
Sources: []string{"test"},
Destinations: []string{"test:*"},
},
},
Tests: []policy.ACLTest{},
}
adminMachine, err := db.GetMachineByID(1)
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
c.Assert(err, check.IsNil)
testMachine, err := db.GetMachineByID(2)
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
c.Assert(err, check.IsNil)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
c.Assert(err, check.IsNil)
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
c.Log(peersOfTestMachine)
c.Assert(len(peersOfTestMachine), check.Equals, 9)
c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1")
c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3")
c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5")
c.Log(peersOfAdminMachine)
c.Assert(len(peersOfAdminMachine), check.Equals, 9)
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestExpireMachine(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := &types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
Expiry: &time.Time{},
}
db.db.Save(machine)
machineFromDB, err := db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(machineFromDB, check.NotNil)
c.Assert(machineFromDB.IsExpired(), check.Equals, false)
err = db.ExpireMachine(machineFromDB)
c.Assert(err, check.IsNil)
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(1))
}
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
input := types.MachineAddresses([]netip.Addr{
netip.MustParseAddr("192.0.2.1"),
netip.MustParseAddr("2001:db8::1"),
})
serialized, err := input.Value()
c.Assert(err, check.IsNil)
if serial, ok := serialized.(string); ok {
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
}
var deserialized types.MachineAddresses
err = deserialized.Scan(serialized)
c.Assert(err, check.IsNil)
c.Assert(len(deserialized), check.Equals, len(input))
for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i])
}
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user-1", "testmachine")
c.Assert(err, check.NotNil)
machine := &types.Machine{
ID: 0,
MachineKey: "machine-key-1",
NodeKey: "node-key-1",
DiscoKey: "disco-key-1",
Hostname: "hostname-1",
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(machine)
givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2")
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-2", comment)
givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1")
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-1", comment)
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := &types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(machine)
// assign simple tags
sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(machine, sTags)
c.Assert(err, check.IsNil)
machine, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags))
// assign duplicat tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(machine, eTags)
c.Assert(err, check.IsNil)
machine, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(
machine.ForcedTags,
check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
)
c.Assert(channelUpdates, check.Equals, int32(4))
}
func TestHeadscale_generateGivenName(t *testing.T) {
type args struct {
suppliedName string
randomSuffix bool
}
tests := []struct {
name string
db *HSDatabase
args args
want *regexp.Regexp
wantErr bool
}{
{
name: "simple machine name generation",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "testmachine",
randomSuffix: false,
},
want: regexp.MustCompile("^testmachine$"),
wantErr: false,
},
{
name: "machine name with 53 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
randomSuffix: false,
},
want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"),
wantErr: false,
},
{
name: "machine name with 63 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: false,
},
want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"),
wantErr: false,
},
{
name: "machine name with 64 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
randomSuffix: false,
},
want: nil,
wantErr: true,
},
{
name: "machine name with 73 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
randomSuffix: false,
},
want: nil,
wantErr: true,
},
{
name: "machine name with random suffix",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "test",
randomSuffix: true,
},
want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
wantErr: false,
},
{
name: "machine name with 63 chars with random suffix",
db: &HSDatabase{
stripEmailDomain: true,
},
args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: true,
},
want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
if (err != nil) != tt.wantErr {
t.Errorf(
"Headscale.GenerateGivenName() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if tt.want != nil && !tt.want.MatchString(got) {
t.Errorf(
"Headscale.GenerateGivenName() = %v, does not match %v",
tt.want,
got,
)
}
if len(got) > util.LabelHostnameLength {
t.Errorf(
"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d",
got,
util.LabelHostnameLength,
)
}
})
}
}
func (s *Suite) TestAutoApproveRoutes(c *check.C) {
acl := []byte(`
{
"tagOwners": {
"tag:exit": ["test"],
},
"groups": {
"group:test": ["test"]
},
"acls": [
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
],
"autoApprovers": {
"exitNode": ["tag:exit"],
"routes": {
"10.10.0.0/16": ["group:test"],
"10.11.0.0/16": ["test"],
}
}
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
nodeKey := key.NewNode()
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
defaultRouteV6 := netip.MustParsePrefix("::/0")
route1 := netip.MustParsePrefix("10.10.0.0/16")
// Check if a subprefix of an autoapproved route is approved
route2 := netip.MustParsePrefix("10.11.0.0/24")
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
}
db.db.Save(&machine)
err = db.ProcessMachineRoutes(&machine)
c.Assert(err, check.IsNil)
machine0ByID, err := db.GetMachineByID(0)
c.Assert(err, check.IsNil)
err = db.EnableAutoApprovedRoutes(pol, machine0ByID)
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
c.Assert(err, check.IsNil)
c.Assert(enabledRoutes, check.HasLen, 4)
c.Assert(channelUpdates, check.Equals, int32(4))
}
func TestMachine_canAccess(t *testing.T) {
type args struct {
filter []tailcfg.FilterRule
machine2 *types.Machine
}
tests := []struct {
name string
machine types.Machine
args args
want bool
}{
{
name: "no-rules",
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"),
},
},
args: args{
filter: []tailcfg.FilterRule{},
machine2: &types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.2"),
},
},
},
want: false,
},
{
name: "wildcard",
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"),
},
},
args: args{
filter: []tailcfg.FilterRule{
{
SrcIPs: []string{"*"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "*",
Ports: tailcfg.PortRange{
First: 0,
Last: 65535,
},
},
},
},
},
machine2: &types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.2"),
},
},
},
want: true,
},
{
name: "explicit-m1-to-m2",
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"),
},
},
args: args{
filter: []tailcfg.FilterRule{
{
SrcIPs: []string{"10.0.0.1"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.0.0.2",
Ports: tailcfg.PortRange{
First: 0,
Last: 65535,
},
},
},
},
},
machine2: &types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.2"),
},
},
},
want: true,
},
{
name: "explicit-m2-to-m1",
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"),
},
},
args: args{
filter: []tailcfg.FilterRule{
{
SrcIPs: []string{"10.0.0.2"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.0.0.1",
Ports: tailcfg.PortRange{
First: 0,
Last: 65535,
},
},
},
},
},
machine2: &types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.2"),
},
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want {
t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,17 +1,14 @@
package hscontrol package db
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -23,28 +20,6 @@ var (
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
) )
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
Key string
UserID uint
User User
Reusable bool
Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
ACLTags []PreAuthKeyACLTag
CreatedAt *time.Time
Expiration *time.Time
}
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
type PreAuthKeyACLTag struct {
ID uint64 `gorm:"primary_key"`
PreAuthKeyID uint64
Tag string
}
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func (hsdb *HSDatabase) CreatePreAuthKey( func (hsdb *HSDatabase) CreatePreAuthKey(
userName string, userName string,
@ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*PreAuthKey, error) { ) (*types.PreAuthKey, error) {
user, err := hsdb.GetUser(userName) user, err := hsdb.GetUser(userName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
return nil, err return nil, err
} }
key := PreAuthKey{ key := types.PreAuthKey{
Key: kstr, Key: kstr,
UserID: user.ID, UserID: user.ID,
User: *user, User: *user,
@ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
for _, tag := range aclTags { for _, tag := range aclTags {
if !seenTags[tag] { if !seenTags[tag] {
if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf( return fmt.Errorf(
"failed to ceate key tag in the database: %w", "failed to ceate key tag in the database: %w",
err, err,
@ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a user. // ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
user, err := hsdb.GetUser(userName) user, err := hsdb.GetUser(userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
keys := []PreAuthKey{} keys := []types.PreAuthKey{}
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
@ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
} }
// GetPreAuthKey returns a PreAuthKey for a given key. // GetPreAuthKey returns a PreAuthKey for a given key.
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
pak, err := hsdb.checkKeyValidity(key) pak, err := hsdb.ValidatePreAuthKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. // does not exist.
func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
return hsdb.db.Transaction(func(db *gorm.DB) error { return hsdb.db.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
return result.Error return result.Error
} }
@ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
} }
// MarkExpirePreAuthKey marks a PreAuthKey as expired. // MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
@ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
} }
// UsePreAuthKey marks a PreAuthKey as used. // UsePreAuthKey marks a PreAuthKey as used.
func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
k.Used = true k.Used = true
if err := hsdb.db.Save(k).Error; err != nil { if err := hsdb.db.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err) return fmt.Errorf("failed to update key used status in the database: %w", err)
@ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
return nil return nil
} }
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used. // If returns no error and a PreAuthKey, it can be used.
func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
pak := PreAuthKey{} pak := types.PreAuthKey{}
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
@ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
return &pak, nil return &pak, nil
} }
machines := []Machine{} machines := types.Machines{}
if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
@ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) {
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
User: key.User.Name,
Id: strconv.FormatUint(key.ID, util.Base10),
Key: key.Key,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
Used: key.Used,
AclTags: make([]string, len(key.ACLTags)),
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
for idx := range key.ACLTags {
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
}
return &protoKey
}

View file

@ -1,20 +1,22 @@
package hscontrol package db
import ( import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Did we get a valid key? // Did we get a valid key?
@ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
// Make sure the User association is populated // Make sure the User association is populated
c.Assert(key.User.Name, check.Equals, user.Name) c.Assert(key.User.Name, check.Equals, user.Name)
_, err = app.db.ListPreAuthKeys("bogus") _, err = db.ListPreAuthKeys("bogus")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
keys, err := app.db.ListPreAuthKeys(user.Name) keys, err := db.ListPreAuthKeys(user.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
@ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
} }
func (*Suite) TestExpiredPreAuthKey(c *check.C) { func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := app.db.CreateUser("test2") user, err := db.CreateUser("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now() now := time.Now()
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
key, err := app.db.checkKeyValidity("potatoKey") key, err := db.ValidatePreAuthKey("potatoKey")
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
c.Assert(key, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestValidateKeyOk(c *check.C) { func (*Suite) TestValidateKeyOk(c *check.C) {
user, err := app.db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestAlreadyUsedKey(c *check.C) { func (*Suite) TestAlreadyUsedKey(c *check.C) {
user, err := app.db.CreateUser("test4") user, err := db.CreateUser("test4")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
c.Assert(key, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestReusableBeingUsedKey(c *check.C) { func (*Suite) TestReusableBeingUsedKey(c *check.C) {
user, err := app.db.CreateUser("test5") user, err := db.CreateUser("test5")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machine := Machine{ machine := types.Machine{
ID: 1, ID: 1,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
user, err := app.db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestEphemeralKey(c *check.C) { func (*Suite) TestEphemeralKey(c *check.C) {
user, err := app.db.CreateUser("test7") user, err := db.CreateUser("test7")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now() now := time.Now().Add(-time.Second * 30)
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
_, err = app.db.checkKeyValidity(pak.Key) _, err = db.ValidatePreAuthKey(pak.Key)
// Ephemeral keys are by definition reusable // Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test7", "testest") _, err = db.GetMachine("test7", "testest")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
app.expireEphemeralNodesWorker() db.ExpireEphemeralMachines(time.Second * 20)
// The machine record should have been deleted // The machine record should have been deleted
_, err = app.db.GetMachine("test7", "testest") _, err = db.GetMachine("test7", "testest")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(1))
} }
func (*Suite) TestExpirePreauthKey(c *check.C) { func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := app.db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil) c.Assert(pak.Expiration, check.IsNil)
err = app.db.ExpirePreAuthKey(pak) err = db.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil) c.Assert(pak.Expiration, check.NotNil)
key, err := app.db.checkKeyValidity(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil) c.Assert(key, check.IsNil)
} }
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
user, err := app.db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak.Used = true pak.Used = true
app.db.db.Save(&pak) db.db.Save(&pak)
_, err = app.db.checkKeyValidity(pak.Key) _, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
} }
func (*Suite) TestPreAuthKeyACLTags(c *check.C) { func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
user, err := app.db.CreateUser("test8") user, err := db.CreateUser("test8")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"} tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
listedPaks, err := app.db.ListPreAuthKeys("test8") listedPaks, err := db.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags)
} }

View file

@ -1,55 +1,19 @@
package hscontrol package db
import ( import (
"errors" "errors"
"fmt"
"net/netip" "net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
) )
var ( var ErrRouteIsNotAvailable = errors.New("route is not available")
ErrRouteIsNotAvailable = errors.New("route is not available")
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
)
type Route struct { func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
gorm.Model var routes types.Routes
MachineID uint64
Machine Machine
Prefix IPPrefix
Advertised bool
Enabled bool
IsPrimary bool
}
type Routes []Route
func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
}
func (r *Route) isExitRoute() bool {
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
}
func (rs Routes) toPrefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs {
prefixes[i] = netip.Prefix(r.Prefix)
}
return prefixes
}
func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
var routes []Route
err := hsdb.db.Preload("Machine").Find(&routes).Error err := hsdb.db.Preload("Machine").Find(&routes).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
var routes []Route var routes types.Routes
err := hsdb.db.
Preload("Machine").
Where("machine_id = ? AND advertised = true", machine.ID).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
var routes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Machine"). Preload("Machine").
Where("machine_id = ?", m.ID). Where("machine_id = ?", m.ID).
@ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
var route Route var route types.Route
err := hsdb.db.Preload("Machine").First(&route, id).Error err := hsdb.db.Preload("Machine").First(&route, id).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.isExitRoute() { if route.IsExitRoute() {
return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) return hsdb.enableRoutes(
&route.Machine,
types.ExitRouteV4.String(),
types.ExitRouteV6.String(),
)
} }
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
@ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.isExitRoute() { if !route.IsExitRoute() {
route.Enabled = false route.Enabled = false
route.IsPrimary = false route.IsPrimary = false
err = hsdb.db.Save(route).Error err = hsdb.db.Save(route).Error
@ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
return err return err
} }
return hsdb.handlePrimarySubnetFailover() return hsdb.HandlePrimarySubnetFailover()
} }
routes, err := hsdb.GetMachineRoutes(&route.Machine) routes, err := hsdb.GetMachineRoutes(&route.Machine)
@ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
} }
for i := range routes { for i := range routes {
if routes[i].isExitRoute() { if routes[i].IsExitRoute() {
routes[i].Enabled = false routes[i].Enabled = false
routes[i].IsPrimary = false routes[i].IsPrimary = false
err = hsdb.db.Save(&routes[i]).Error err = hsdb.db.Save(&routes[i]).Error
@ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
} }
} }
return hsdb.handlePrimarySubnetFailover() return hsdb.HandlePrimarySubnetFailover()
} }
func (hsdb *HSDatabase) DeleteRoute(id uint64) error { func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
@ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.isExitRoute() { if !route.IsExitRoute() {
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err return err
} }
return hsdb.handlePrimarySubnetFailover() return hsdb.HandlePrimarySubnetFailover()
} }
routes, err := hsdb.GetMachineRoutes(&route.Machine) routes, err := hsdb.GetMachineRoutes(&route.Machine)
@ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err return err
} }
routesToDelete := []Route{} routesToDelete := types.Routes{}
for _, r := range routes { for _, r := range routes {
if r.isExitRoute() { if r.IsExitRoute() {
routesToDelete = append(routesToDelete, r) routesToDelete = append(routesToDelete, r)
} }
} }
@ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err return err
} }
return hsdb.handlePrimarySubnetFailover() return hsdb.HandlePrimarySubnetFailover()
} }
func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
routes, err := hsdb.GetMachineRoutes(m) routes, err := hsdb.GetMachineRoutes(m)
if err != nil { if err != nil {
return err return err
@ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
} }
} }
return hsdb.handlePrimarySubnetFailover() return hsdb.HandlePrimarySubnetFailover()
} }
// isUniquePrefix returns if there is another machine providing the same route already. // isUniquePrefix returns if there is another machine providing the same route already.
func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
var count int64 var count int64
hsdb.db. hsdb.db.
Model(&Route{}). Model(&types.Route{}).
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
route.Prefix, route.Prefix,
route.MachineID, route.MachineID,
@ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
return count == 0 return count == 0
} }
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
var route Route var route types.Route
err := hsdb.db. err := hsdb.db.
Preload("Machine"). Preload("Machine").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
First(&route).Error First(&route).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err return nil, err
@ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary. // Exit nodes are not considered for this, as they are never marked as Primary.
func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
var routes []Route var routes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Machine"). Preload("Machine").
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
@ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
currentRoutes := []Route{} currentRoutes := types.Routes{}
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error err := hsdb.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error
if err != nil { if err != nil {
return err return err
@ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
for prefix, exists := range advertisedRoutes { for prefix, exists := range advertisedRoutes {
if !exists { if !exists {
route := Route{ route := types.Route{
MachineID: machine.ID, MachineID: machine.ID,
Prefix: IPPrefix(prefix), Prefix: types.IPPrefix(prefix),
Advertised: true, Advertised: true,
Enabled: false, Enabled: false,
} }
@ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
return nil return nil
} }
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
// first, get all the enabled routes // first, get all the enabled routes
var routes []Route var routes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Machine"). Preload("Machine").
Where("advertised = ? AND enabled = ?", true, true). Where("advertised = ? AND enabled = ?", true, true).
@ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
routesChanged := false routesChanged := false
for pos, route := range routes { for pos, route := range routes {
if route.isExitRoute() { if route.IsExitRoute() {
continue continue
} }
@ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
} }
if route.IsPrimary { if route.IsPrimary {
if route.Machine.isOnline() { if route.Machine.IsOnline() {
continue continue
} }
@ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
Msgf("machine offline, finding a new primary subnet") Msgf("machine offline, finding a new primary subnet")
// find a new primary route // find a new primary route
var newPrimaryRoutes []Route var newPrimaryRoutes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Machine"). Preload("Machine").
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
@ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
return err return err
} }
var newPrimaryRoute *Route var newPrimaryRoute *types.Route
for pos, r := range newPrimaryRoutes { for pos, r := range newPrimaryRoutes {
if r.Machine.isOnline() { if r.Machine.IsOnline() {
newPrimaryRoute = &newPrimaryRoutes[pos] newPrimaryRoute = &newPrimaryRoutes[pos]
break break
@ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
return nil return nil
} }
func (rs Routes) toProto() []*v1.Route { // EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
protoRoutes := []*v1.Route{} func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
for _, route := range rs { machine *types.Machine,
protoRoute := v1.Route{ ) error {
Id: uint64(route.ID), if len(machine.IPAddresses) == 0 {
Machine: route.Machine.toProto(), return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
Prefix: netip.Prefix(route.Prefix).String(),
Advertised: route.Advertised,
Enabled: route.Enabled,
IsPrimary: route.IsPrimary,
CreatedAt: timestamppb.New(route.CreatedAt),
UpdatedAt: timestamppb.New(route.UpdatedAt),
} }
if route.DeletedAt.Valid { routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().
Caller().
Err(err).
Str("machine", machine.Hostname).
Msg("Could not get advertised routes for machine")
return err
} }
protoRoutes = append(protoRoutes, &protoRoute) approvedRoutes := types.Routes{}
for _, advertisedRoute := range routes {
if advertisedRoute.Enabled {
continue
} }
return protoRoutes routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
netip.Prefix(advertisedRoute.Prefix),
)
if err != nil {
log.Err(err).
Str("advertisedRoute", advertisedRoute.String()).
Uint64("machineId", machine.ID).
Msg("Failed to resolve autoApprovers for advertised route")
return err
}
for _, approvedAlias := range routeApprovers {
if approvedAlias == machine.User.Name {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
if err != nil {
log.Err(err).
Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy")
return err
}
// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first
if approvedIps.Contains(machine.IPAddresses[0]) {
approvedRoutes = append(approvedRoutes, advertisedRoute)
}
}
}
}
for _, approvedRoute := range approvedRoutes {
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).
Uint64("machineId", machine.ID).
Msg("Failed to enable approved route")
return err
}
}
return nil
} }

View file

@ -1,9 +1,11 @@
package hscontrol package db
import ( import (
"net/netip" "net/netip"
"time" "time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -11,13 +13,13 @@ import (
) )
func (s *Suite) TestGetRoutes(c *check.C) { func (s *Suite) TestGetRoutes(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_get_route_machine") _, err = db.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24") route, err := netip.ParsePrefix("10.0.0.0/24")
@ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route}, RoutableIPs: []netip.Prefix{route},
} }
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_get_route_machine", Hostname: "test_get_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo), HostInfo: types.HostInfo(hostInfo),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
err = app.db.processMachineRoutes(&machine) err = db.ProcessMachineRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) advertisedRoutes, err := db.GetAdvertisedRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1) c.Assert(len(advertisedRoutes), check.Equals, 1)
err = app.db.enableRoutes(&machine, "192.168.0.0/24") err = db.enableRoutes(&machine, "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = app.db.enableRoutes(&machine, "10.0.0.0/24") err = db.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetEnableRoutes(c *check.C) { func (s *Suite) TestGetEnableRoutes(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_enable_route_machine") _, err = db.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
machine := Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo), HostInfo: types.HostInfo(hostInfo),
} }
app.db.db.Save(&machine) db.db.Save(&machine)
err = app.db.processMachineRoutes(&machine) err = db.ProcessMachineRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) availableRoutes, err := db.GetAdvertisedRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(availableRoutes), check.Equals, 2) c.Assert(len(availableRoutes), check.Equals, 2)
noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) noEnabledRoutes, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0) c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = app.db.enableRoutes(&machine, "192.168.0.0/24") err = db.enableRoutes(&machine, "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = app.db.enableRoutes(&machine, "10.0.0.0/24") err = db.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes, err := app.db.GetEnabledRoutes(&machine) enabledRoutes, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 1) c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through // Adding it twice will just let it pass through
err = app.db.enableRoutes(&machine, "10.0.0.0/24") err = db.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = app.db.enableRoutes(&machine, "150.0.10.0/25") err = db.enableRoutes(&machine, "150.0.10.0/25")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(3))
} }
func (s *Suite) TestIsUniquePrefix(c *check.C) { func (s *Suite) TestIsUniquePrefix(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_enable_route_machine") _, err = db.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
hostInfo1 := tailcfg.Hostinfo{ hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
machine1 := Machine{ machine1 := types.Machine{
ID: 1, ID: 1,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1), HostInfo: types.HostInfo(hostInfo1),
} }
app.db.db.Save(&machine1) db.db.Save(&machine1)
err = app.db.processMachineRoutes(&machine1) err = db.ProcessMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, route.String()) err = db.enableRoutes(&machine1, route.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, route2.String()) err = db.enableRoutes(&machine1, route2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{ hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route2}, RoutableIPs: []netip.Prefix{route2},
} }
machine2 := Machine{ machine2 := types.Machine{
ID: 2, ID: 2,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo2), HostInfo: types.HostInfo(hostInfo2),
} }
app.db.db.Save(&machine2) db.db.Save(&machine2)
err = app.db.processMachineRoutes(&machine2) err = db.ProcessMachineRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine2, route2.String()) err = db.enableRoutes(&machine2, route2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2) c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1) c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := app.db.getMachinePrimaryRoutes(&machine1) routes, err := db.GetMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2) c.Assert(len(routes), check.Equals, 2)
routes, err = app.db.getMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0) c.Assert(len(routes), check.Equals, 0)
c.Assert(channelUpdates, check.Equals, int32(3))
} }
func (s *Suite) TestSubnetFailover(c *check.C) { func (s *Suite) TestSubnetFailover(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_enable_route_machine") _, err = db.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(
@ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
} }
now := time.Now() now := time.Now()
machine1 := Machine{ machine1 := types.Machine{
ID: 1, ID: 1,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1), HostInfo: types.HostInfo(hostInfo1),
LastSeen: &now, LastSeen: &now,
} }
app.db.db.Save(&machine1) db.db.Save(&machine1)
err = app.db.processMachineRoutes(&machine1) err = db.ProcessMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, prefix.String()) err = db.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, prefix2.String()) err = db.enableRoutes(&machine1, prefix2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.handlePrimarySubnetFailover() err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2) c.Assert(len(enabledRoutes1), check.Equals, 2)
route, err := app.db.getPrimaryRoute(prefix) route, err := db.getPrimaryRoute(prefix)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(route.MachineID, check.Equals, machine1.ID) c.Assert(route.MachineID, check.Equals, machine1.ID)
hostInfo2 := tailcfg.Hostinfo{ hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix2}, RoutableIPs: []netip.Prefix{prefix2},
} }
machine2 := Machine{ machine2 := types.Machine{
ID: 2, ID: 2,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo2), HostInfo: types.HostInfo(hostInfo2),
LastSeen: &now, LastSeen: &now,
} }
app.db.db.Save(&machine2) db.db.Save(&machine2)
err = app.db.processMachineRoutes(&machine2) err = db.ProcessMachineRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine2, prefix2.String()) err = db.enableRoutes(&machine2, prefix2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.handlePrimarySubnetFailover() err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2) c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1) c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := app.db.getMachinePrimaryRoutes(&machine1) routes, err := db.GetMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2) c.Assert(len(routes), check.Equals, 2)
routes, err = app.db.getMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0) c.Assert(len(routes), check.Equals, 0)
// lets make machine1 lastseen 10 mins ago // lets make machine1 lastseen 10 mins ago
before := now.Add(-10 * time.Minute) before := now.Add(-10 * time.Minute)
machine1.LastSeen = &before machine1.LastSeen = &before
err = app.db.db.Save(&machine1).Error err = db.db.Save(&machine1).Error
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.handlePrimarySubnetFailover() err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
routes, err = app.db.getMachinePrimaryRoutes(&machine1) routes, err = db.GetMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1) c.Assert(len(routes), check.Equals, 1)
routes, err = app.db.getMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1) c.Assert(len(routes), check.Equals, 1)
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2}, RoutableIPs: []netip.Prefix{prefix, prefix2},
}) })
err = app.db.db.Save(&machine2).Error err = db.db.Save(&machine2).Error
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.processMachineRoutes(&machine2) err = db.ProcessMachineRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine2, prefix.String()) err = db.enableRoutes(&machine2, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.handlePrimarySubnetFailover() err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
routes, err = app.db.getMachinePrimaryRoutes(&machine1) routes, err = db.GetMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0) c.Assert(len(routes), check.Equals, 0)
routes, err = app.db.getMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2) c.Assert(len(routes), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(6))
} }
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
// including both the primary routes the node is responsible for, and the // including both the primary routes the node is responsible for, and the
// exit node routes if enabled. // exit node routes if enabled.
func (s *Suite) TestAllowedIPRoutes(c *check.C) { func (s *Suite) TestAllowedIPRoutes(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_enable_route_machine") _, err = db.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(
@ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
machineKey := key.NewMachine() machineKey := key.NewMachine()
now := time.Now() now := time.Now()
machine1 := Machine{ machine1 := types.Machine{
ID: 1, ID: 1,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1), HostInfo: types.HostInfo(hostInfo1),
LastSeen: &now, LastSeen: &now,
} }
app.db.db.Save(&machine1) db.db.Save(&machine1)
err = app.db.processMachineRoutes(&machine1) err = db.ProcessMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, prefix.String()) err = db.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// We do not enable this one on purpose to test that it is not enabled // We do not enable this one on purpose to test that it is not enabled
// err = app.db.enableRoutes(&machine1, prefix2.String()) // err = db.enableRoutes(&machine1, prefix2.String())
// c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
routes, err := app.db.GetMachineRoutes(&machine1) routes, err := db.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
for _, route := range routes { for _, route := range routes {
if route.isExitRoute() { if route.IsExitRoute() {
err = app.db.EnableRoute(uint64(route.ID)) err = db.EnableRoute(uint64(route.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// We only enable one exit route, so we can test that both are enabled // We only enable one exit route, so we can test that both are enabled
@ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
} }
} }
err = app.db.handlePrimarySubnetFailover() err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 3) c.Assert(len(enabledRoutes1), check.Equals, 3)
peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(peer.AllowedIPs), check.Equals, 3) c.Assert(len(peer.AllowedIPs), check.Equals, 3)
@ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
// Now we disable only one of the exit routes // Now we disable only one of the exit routes
// and we see if both are disabled // and we see if both are disabled
var exitRouteV4 Route var exitRouteV4 types.Route
for _, route := range routes { for _, route := range routes {
if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
exitRouteV4 = route exitRouteV4 = route
break break
} }
} }
err = app.db.DisableRoute(uint64(exitRouteV4.ID)) err = db.DisableRoute(uint64(exitRouteV4.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes1), check.Equals, 1)
// and now we delete only one of the exit routes // and now we delete only one of the exit routes
// and we check if both are deleted // and we check if both are deleted
routes, err = app.db.GetMachineRoutes(&machine1) routes, err = db.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 4) c.Assert(len(routes), check.Equals, 4)
err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) err = db.DeleteRoute(uint64(exitRouteV4.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
routes, err = app.db.GetMachineRoutes(&machine1) routes, err = db.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2) c.Assert(len(routes), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(2))
} }
func (s *Suite) TestDeleteRoutes(c *check.C) { func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := app.db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.db.GetMachine("test", "test_enable_route_machine") _, err = db.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(
@ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
} }
now := time.Now() now := time.Now()
machine1 := Machine{ machine1 := types.Machine{
ID: 1, ID: 1,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "test_enable_route_machine", Hostname: "test_enable_route_machine",
UserID: user.ID, UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1), HostInfo: types.HostInfo(hostInfo1),
LastSeen: &now, LastSeen: &now,
} }
app.db.db.Save(&machine1) db.db.Save(&machine1)
err = app.db.processMachineRoutes(&machine1) err = db.ProcessMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, prefix.String()) err = db.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.enableRoutes(&machine1, prefix2.String()) err = db.enableRoutes(&machine1, prefix2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
routes, err := app.db.GetMachineRoutes(&machine1) routes, err := db.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.db.DeleteRoute(uint64(routes[0].ID)) err = db.DeleteRoute(uint64(routes[0].ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes1), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(2))
} }

View file

@ -0,0 +1,74 @@
package db
import (
"net/netip"
"os"
"sync/atomic"
"testing"
"gopkg.in/check.v1"
)
func Test(t *testing.T) {
check.TestingT(t)
}
var _ = check.Suite(&Suite{})
type Suite struct{}
var (
tmpDir string
db *HSDatabase
// channelUpdates counts the number of times
// either of the channels was notified.
channelUpdates int32
)
func (s *Suite) SetUpTest(c *check.C) {
atomic.StoreInt32(&channelUpdates, 0)
s.ResetDB(c)
}
func (s *Suite) TearDownTest(c *check.C) {
os.RemoveAll(tmpDir)
}
func notificationSink(c <-chan struct{}) {
for {
<-c
atomic.AddInt32(&channelUpdates, 1)
}
}
func (s *Suite) ResetDB(c *check.C) {
if len(tmpDir) != 0 {
os.RemoveAll(tmpDir)
}
var err error
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
if err != nil {
c.Fatal(err)
}
sink := make(chan struct{})
go notificationSink(sink)
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
false,
sink,
sink,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
if err != nil {
c.Fatal(err)
}
}

View file

@ -1,17 +1,12 @@
package hscontrol package db
import ( import (
"errors" "errors"
"fmt" "fmt"
"regexp"
"strconv"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -20,33 +15,16 @@ var (
ErrUserExists = errors.New("user already exists") ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrInvalidUserName = errors.New("invalid user name")
) )
const (
// value related to RFC 1123 and 952.
labelHostnameLength = 63
)
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
// User is the way Headscale implements the concept of users in Tailscale
//
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
// that contain our machines.
type User struct {
gorm.Model
Name string `gorm:"unique"`
}
// CreateUser creates a new User. Returns error if could not be created // CreateUser creates a new User. Returns error if could not be created
// or another user already exists. // or another user already exists.
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
err := CheckForFQDNRules(name) err := util.CheckForFQDNRules(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user := User{} user := types.User{}
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists return nil, ErrUserExists
} }
@ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
if err != nil { if err != nil {
return err return err
} }
err = CheckForFQDNRules(newName) err = util.CheckForFQDNRules(newName)
if err != nil { if err != nil {
return err return err
} }
@ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
} }
// GetUser fetches a user by name. // GetUser fetches a user by name.
func (hsdb *HSDatabase) GetUser(name string) (*User, error) { func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
user := User{} user := types.User{}
if result := hsdb.db.First(&user, "name = ?", name); errors.Is( if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
@ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
} }
// ListUsers gets all the existing users. // ListUsers gets all the existing users.
func (hsdb *HSDatabase) ListUsers() ([]User, error) { func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
users := []User{} users := []types.User{}
if err := hsdb.db.Find(&users).Error; err != nil { if err := hsdb.db.Find(&users).Error; err != nil {
return nil, err return nil, err
} }
@ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) {
} }
// ListMachinesByUser gets all the nodes in a given user. // ListMachinesByUser gets all the nodes in a given user.
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) {
err := CheckForFQDNRules(name) err := util.CheckForFQDNRules(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
return nil, err return nil, err
} }
machines := []Machine{} machines := types.Machines{}
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
@ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
} }
// SetMachineUser assigns a Machine to a user. // SetMachineUser assigns a Machine to a user.
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error {
err := CheckForFQDNRules(username) err := util.CheckForFQDNRules(username)
if err != nil { if err != nil {
return err return err
} }
@ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error
return nil return nil
} }
func (n *User) toTailscaleUser() *tailcfg.User { func (hsdb *HSDatabase) GetMapResponseUserProfiles(
user := tailcfg.User{ machine types.Machine,
ID: tailcfg.UserID(n.ID), peers types.Machines,
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
Logins: []tailcfg.LoginID{},
Created: time.Time{},
}
return &user
}
func (n *User) toTailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
return &login
}
func (hsdb *HSDatabase) getMapResponseUserProfiles(
machine Machine,
peers Machines,
) []tailcfg.UserProfile { ) []tailcfg.UserProfile {
userMap := make(map[string]User) userMap := make(map[string]types.User)
userMap[machine.User.Name] = machine.User userMap[machine.User.Name] = machine.User
for _, peer := range peers { for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there userMap[peer.User.Name] = peer.User // not worth checking if already is there
@ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles(
return profiles return profiles
} }
func (n *User) toProto() *v1.User {
return &v1.User{
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}
}
// NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > labelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidUserName,
)
}
}
return name, nil
}
func CheckForFQDNRules(name string) error {
if len(name) > labelHostnameLength {
return fmt.Errorf(
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if invalidCharsInUserRegex.MatchString(name) {
return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name,
ErrInvalidUserName,
)
}
return nil
}

277
hscontrol/db/users_test.go Normal file
View file

@ -0,0 +1,277 @@
package db
import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(user.Name, check.Equals, "test")
users, err := db.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
user, err = db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes)
}
func (s *Suite) TestRenameUser(c *check.C) {
userTest, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(userTest.Name, check.Equals, "test")
users, err := db.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
_, err = db.GetUser("test-renamed")
c.Assert(err, check.IsNil)
err = db.RenameUser("test-does-not-exit", "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := db.CreateUser("test2")
c.Assert(err, check.IsNil)
c.Assert(userTest2.Name, check.Equals, "test2")
err = db.RenameUser("test2", "test-renamed")
c.Assert(err, check.Equals, ErrUserExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
userShared1, err := db.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := db.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := db.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyShared1, err := db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared2, err := db.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared3, err := db.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2Shared1, err := db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &types.Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID),
}
db.db.Save(machineInShared1)
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &types.Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID),
}
db.db.Save(machineInShared2)
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &types.Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID),
}
db.db.Save(machineInShared3)
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &types.Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
db.db.Save(machine2InShared1)
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
c.Assert(err, check.IsNil)
userProfiles := db.GetMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
c.Assert(len(userProfiles), check.Equals, 3)
found := false
for _, userProfiles := range userProfiles {
if userProfiles.DisplayName == userShared1.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == userShared2.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := db.CreateUser("old")
c.Assert(err, check.IsNil)
newUser, err := db.CreateUser("new")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&machine)
c.Assert(machine.UserID, check.Equals, oldUser.ID)
err = db.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
err = db.SetMachineUser(&machine, "non-existing-user")
c.Assert(err, check.Equals, ErrUserNotFound)
err = db.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
}

View file

@ -7,6 +7,7 @@ import (
"strings" "strings"
mapset "github.com/deckarep/golang-set/v2" mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
@ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// //
// This will produce a resolver like: // This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
for _, resolver := range resolvers { for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{ attrs := url.Values{
@ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
func getMapResponseDNSConfig( func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig, dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string, baseDomain string,
machine Machine, machine types.Machine,
peers Machines, peers types.Machines,
) *tailcfg.DNSConfig { ) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
@ -200,7 +201,7 @@ func getMapResponseDNSConfig(
), ),
) )
userSet := mapset.NewSet[User]() userSet := mapset.NewSet[types.User]()
userSet.Add(machine.User) userSet.Add(machine.User)
for _, p := range peers { for _, p := range peers {
userSet.Add(p.User) userSet.Add(p.User)

View file

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
@ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
machineInShared1 := &Machine{ machineInShared1 := &types.Machine{
ID: 1, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
@ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_1", Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID, UserID: userShared1.ID,
User: *userShared1, User: *userShared1,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
app.db.db.Save(machineInShared1) err = app.db.MachineSave(machineInShared1)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineInShared2 := &Machine{ machineInShared2 := &types.Machine{
ID: 2, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_2", Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID, UserID: userShared2.ID,
User: *userShared2, User: *userShared2,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
app.db.db.Save(machineInShared2) err = app.db.MachineSave(machineInShared2)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineInShared3 := &Machine{ machineInShared3 := &types.Machine{
ID: 3, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_3", Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID, UserID: userShared3.ID,
User: *userShared3, User: *userShared3,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
app.db.db.Save(machineInShared3) err = app.db.MachineSave(machineInShared3)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{ machine2InShared1 := &types.Machine{
ID: 4, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_4", Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID, UserID: userShared1.ID,
User: *userShared1, User: *userShared1,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID), AuthKeyID: uint(PreAuthKey2InShared1.ID),
} }
app.db.db.Save(machine2InShared1) err = app.db.MachineSave(machine2InShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net" baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{ dnsConfigOrig := tailcfg.DNSConfig{
@ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Proxied: true, Proxied: true,
} }
peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig( dnsConfig := getMapResponseDNSConfig(
@ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
machineInShared1 := &Machine{ machineInShared1 := &types.Machine{
ID: 1, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
@ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_1", Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID, UserID: userShared1.ID,
User: *userShared1, User: *userShared1,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
} }
app.db.db.Save(machineInShared1) err = app.db.MachineSave(machineInShared1)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineInShared2 := &Machine{ machineInShared2 := &types.Machine{
ID: 2, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_2", Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID, UserID: userShared2.ID,
User: *userShared2, User: *userShared2,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
} }
app.db.db.Save(machineInShared2) err = app.db.MachineSave(machineInShared2)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineInShared3 := &Machine{ machineInShared3 := &types.Machine{
ID: 3, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_3", Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID, UserID: userShared3.ID,
User: *userShared3, User: *userShared3,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
} }
app.db.db.Save(machineInShared3) err = app.db.MachineSave(machineInShared3)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{ machine2InShared1 := &types.Machine{
ID: 4, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Hostname: "test_get_shared_nodes_4", Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID, UserID: userShared1.ID,
User: *userShared1, User: *userShared1,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID), AuthKeyID: uint(preAuthKey2InShared1.ID),
} }
app.db.db.Save(machine2InShared1) err = app.db.MachineSave(machine2InShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net" baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{ dnsConfigOrig := tailcfg.DNSConfig{
@ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Proxied: false, Proxied: false,
} }
peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig( dnsConfig := getMapResponseDNSConfig(

View file

@ -8,6 +8,7 @@ import (
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser(
return nil, err return nil, err
} }
return &v1.GetUserResponse{User: user.toProto()}, nil return &v1.GetUserResponse{User: user.Proto()}, nil
} }
func (api headscaleV1APIServer) CreateUser( func (api headscaleV1APIServer) CreateUser(
@ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser(
return nil, err return nil, err
} }
return &v1.CreateUserResponse{User: user.toProto()}, nil return &v1.CreateUserResponse{User: user.Proto()}, nil
} }
func (api headscaleV1APIServer) RenameUser( func (api headscaleV1APIServer) RenameUser(
@ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err return nil, err
} }
return &v1.RenameUserResponse{User: user.toProto()}, nil return &v1.RenameUserResponse{User: user.Proto()}, nil
} }
func (api headscaleV1APIServer) DeleteUser( func (api headscaleV1APIServer) DeleteUser(
@ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers(
response := make([]*v1.User, len(users)) response := make([]*v1.User, len(users))
for index, user := range users { for index, user := range users {
response[index] = user.toProto() response[index] = user.Proto()
} }
log.Trace().Caller().Interface("users", response).Msg("") log.Trace().Caller().Interface("users", response).Msg("")
@ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
return nil, err return nil, err
} }
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil
} }
func (api headscaleV1APIServer) ExpirePreAuthKey( func (api headscaleV1APIServer) ExpirePreAuthKey(
@ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
response := make([]*v1.PreAuthKey, len(preAuthKeys)) response := make([]*v1.PreAuthKey, len(preAuthKeys))
for index, key := range preAuthKeys { for index, key := range preAuthKeys {
response[index] = key.toProto() response[index] = key.Proto()
} }
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
@ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine(
request.GetKey(), request.GetKey(),
request.GetUser(), request.GetUser(),
nil, nil,
RegisterMethodCLI, util.RegisterMethodCLI,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil
} }
func (api headscaleV1APIServer) GetMachine( func (api headscaleV1APIServer) GetMachine(
@ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine(
return nil, err return nil, err
} }
return &v1.GetMachineResponse{Machine: machine.toProto()}, nil return &v1.GetMachineResponse{Machine: machine.Proto()}, nil
} }
func (api headscaleV1APIServer) SetTags( func (api headscaleV1APIServer) SetTags(
@ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags(
} }
} }
err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) err = api.h.db.SetTags(machine, request.GetTags())
if err != nil { if err != nil {
return &v1.SetTagsResponse{ return &v1.SetTagsResponse{
Machine: nil, Machine: nil,
@ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags(
Strs("tags", request.GetTags()). Strs("tags", request.GetTags()).
Msg("Changing tags of machine") Msg("Changing tags of machine")
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil return &v1.SetTagsResponse{Machine: machine.Proto()}, nil
} }
func validateTag(tag string) error { func validateTag(tag string) error {
@ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine(
Time("expiry", *machine.Expiry). Time("expiry", *machine.Expiry).
Msg("machine expired") Msg("machine expired")
return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil
} }
func (api headscaleV1APIServer) RenameMachine( func (api headscaleV1APIServer) RenameMachine(
@ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine(
Str("new_name", request.GetNewName()). Str("new_name", request.GetNewName()).
Msg("machine renamed") Msg("machine renamed")
return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil
} }
func (api headscaleV1APIServer) ListMachines( func (api headscaleV1APIServer) ListMachines(
@ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines(
response := make([]*v1.Machine, len(machines)) response := make([]*v1.Machine, len(machines))
for index, machine := range machines { for index, machine := range machines {
response[index] = machine.toProto() response[index] = machine.Proto()
} }
return &v1.ListMachinesResponse{Machines: response}, nil return &v1.ListMachinesResponse{Machines: response}, nil
@ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines(
response := make([]*v1.Machine, len(machines)) response := make([]*v1.Machine, len(machines))
for index, machine := range machines { for index, machine := range machines {
m := machine.toProto() m := machine.Proto()
validTags, invalidTags := getTags( validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
api.h.aclPolicy,
machine, machine,
api.h.cfg.OIDC.StripEmaildomain, api.h.cfg.OIDC.StripEmaildomain,
) )
@ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine(
return nil, err return nil, err
} }
return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil
} }
func (api headscaleV1APIServer) GetRoutes( func (api headscaleV1APIServer) GetRoutes(
@ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes(
} }
return &v1.GetRoutesResponse{ return &v1.GetRoutesResponse{
Routes: Routes(routes).toProto(), Routes: types.Routes(routes).Proto(),
}, nil }, nil
} }
@ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes(
} }
return &v1.GetMachineRoutesResponse{ return &v1.GetMachineRoutesResponse{
Routes: Routes(routes).toProto(), Routes: types.Routes(routes).Proto(),
}, nil }, nil
} }
@ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey(
ctx context.Context, ctx context.Context,
request *v1.ExpireApiKeyRequest, request *v1.ExpireApiKeyRequest,
) (*v1.ExpireApiKeyResponse, error) { ) (*v1.ExpireApiKeyResponse, error) {
var apiKey *APIKey var apiKey *types.APIKey
var err error var err error
apiKey, err = api.h.db.GetAPIKey(request.Prefix) apiKey, err = api.h.db.GetAPIKey(request.Prefix)
@ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys(
response := make([]*v1.ApiKey, len(apiKeys)) response := make([]*v1.ApiKey, len(apiKeys))
for index, key := range apiKeys { for index, key := range apiKeys {
response[index] = key.toProto() response[index] = key.Proto()
} }
return &v1.ListApiKeysResponse{ApiKeys: response}, nil return &v1.ListApiKeysResponse{ApiKeys: response}, nil
@ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err return nil, err
} }
newMachine := Machine{ newMachine := types.Machine{
MachineKey: request.GetKey(), MachineKey: request.GetKey(),
Hostname: request.GetName(), Hostname: request.GetName(),
GivenName: givenName, GivenName: givenName,
@ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
LastSeen: &time.Time{}, LastSeen: &time.Time{},
LastSuccessfulUpdate: &time.Time{}, LastSuccessfulUpdate: &time.Time{},
HostInfo: HostInfo(hostinfo), HostInfo: types.HostInfo(hostinfo),
} }
nodeKey := key.NodePublic{} nodeKey := key.NodePublic{}
@ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
registerCacheExpiration, registerCacheExpiration,
) )
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil
} }
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

File diff suppressed because it is too large Load diff

View file

@ -1,142 +0,0 @@
package hscontrol
import (
"fmt"
"net/netip"
"strings"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
// This is borrowed from, and updated to use IPSet
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
// TODO(kradalby): contribute upstream and make public.
var (
zeroIP4 = netip.AddrFrom4([4]byte{})
zeroIP6 = netip.AddrFrom16([16]byte{})
)
// parseIPSet parses arg as one:
//
// - an IP address (IPv4 or IPv6)
// - the string "*" to match everything (both IPv4 & IPv6)
// - a CIDR (e.g. "192.168.0.0/16")
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
//
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
// address (without a slash) treated as a CIDR of *bits length.
// nolint
func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
var ipSet netipx.IPSetBuilder
if arg == "*" {
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
return ipSet.IPSet()
}
if strings.Contains(arg, "/") {
pfx, err := netip.ParsePrefix(arg)
if err != nil {
return nil, err
}
if pfx != pfx.Masked() {
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
}
ipSet.AddPrefix(pfx)
return ipSet.IPSet()
}
if strings.Count(arg, "-") == 1 {
ip1s, ip2s, _ := strings.Cut(arg, "-")
ip1, err := netip.ParseAddr(ip1s)
if err != nil {
return nil, err
}
ip2, err := netip.ParseAddr(ip2s)
if err != nil {
return nil, err
}
r := netipx.IPRangeFrom(ip1, ip2)
if !r.IsValid() {
return nil, fmt.Errorf("invalid IP range %q", arg)
}
for _, prefix := range r.Prefixes() {
ipSet.AddPrefix(prefix)
}
return ipSet.IPSet()
}
ip, err := netip.ParseAddr(arg)
if err != nil {
return nil, fmt.Errorf("invalid IP address %q", arg)
}
bits8 := uint8(ip.BitLen())
if bits != nil {
if *bits < 0 || *bits > int(bits8) {
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
}
bits8 = uint8(*bits)
}
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
return ipSet.IPSet()
}
type Match struct {
Srcs *netipx.IPSet
Dests *netipx.IPSet
}
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
srcs := new(netipx.IPSetBuilder)
dests := new(netipx.IPSetBuilder)
for _, srcIP := range rule.SrcIPs {
set, _ := parseIPSet(srcIP, nil)
srcs.AddSet(set)
}
for _, dest := range rule.DstPorts {
set, _ := parseIPSet(dest.IP, nil)
dests.AddSet(set)
}
srcsSet, _ := srcs.IPSet()
destsSet, _ := dests.IPSet()
match := Match{
Srcs: srcsSet,
Dests: destsSet,
}
return match
}
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Srcs.Contains(ip) {
return true
}
}
return false
}
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Dests.Contains(ip) {
return true
}
}
return false
}

View file

@ -14,6 +14,8 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -638,7 +640,7 @@ func getUserName(
claims *IDTokenClaims, claims *IDTokenClaims,
stripEmaildomain bool, stripEmaildomain bool,
) (string, error) { ) (string, error) {
userName, err := NormalizeToFQDNRules( userName, err := util.NormalizeToFQDNRules(
claims.Email, claims.Email,
stripEmaildomain, stripEmaildomain,
) )
@ -663,9 +665,9 @@ func getUserName(
func (h *Headscale) findOrCreateNewUserForOIDCCallback( func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
userName string, userName string,
) (*User, error) { ) (*types.User, error) {
user, err := h.db.GetUser(userName) user, err := h.db.GetUser(userName)
if errors.Is(err, ErrUserNotFound) { if errors.Is(err, db.ErrUserNotFound) {
user, err = h.db.CreateUser(userName) user, err = h.db.CreateUser(userName)
if err != nil { if err != nil {
@ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
func (h *Headscale) registerMachineForOIDCCallback( func (h *Headscale) registerMachineForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
user *User, user *types.User,
nodeKey *key.NodePublic, nodeKey *key.NodePublic,
expiry time.Time, expiry time.Time,
) error { ) error {
@ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback(
nodeKey.String(), nodeKey.String(),
user.Name, user.Name,
&expiry, &expiry,
RegisterMethodOIDC, util.RegisterMethodOIDC,
); err != nil { ); err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -1,4 +1,4 @@
package hscontrol package policy
import ( import (
"encoding/json" "encoding/json"
@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
@ -22,12 +23,12 @@ import (
) )
var ( var (
errEmptyPolicy = errors.New("empty policy") ErrEmptyPolicy = errors.New("empty policy")
errInvalidAction = errors.New("invalid action") ErrInvalidAction = errors.New("invalid action")
errInvalidGroup = errors.New("invalid group") ErrInvalidGroup = errors.New("invalid group")
errInvalidTag = errors.New("invalid tag") ErrInvalidTag = errors.New("invalid tag")
errInvalidPortFormat = errors.New("invalid port format") ErrInvalidPortFormat = errors.New("invalid port format")
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
) )
const ( const (
@ -56,7 +57,7 @@ const (
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicyFromPath(path string) error { func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
log.Debug(). log.Debug().
Str("func", "LoadACLPolicy"). Str("func", "LoadACLPolicy").
Str("path", path). Str("path", path).
@ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
policyFile, err := os.Open(path) policyFile, err := os.Open(path)
if err != nil { if err != nil {
return err return nil, err
} }
defer policyFile.Close() defer policyFile.Close()
policyBytes, err := io.ReadAll(policyFile) policyBytes, err := io.ReadAll(policyFile)
if err != nil { if err != nil {
return err return nil, err
} }
log.Debug(). log.Debug().
@ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
switch filepath.Ext(path) { switch filepath.Ext(path) {
case ".yml", ".yaml": case ".yml", ".yaml":
return h.LoadACLPolicyFromBytes(policyBytes, "yaml") return LoadACLPolicyFromBytes(policyBytes, "yaml")
} }
return h.LoadACLPolicyFromBytes(policyBytes, "hujson") return LoadACLPolicyFromBytes(policyBytes, "hujson")
} }
func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
var policy ACLPolicy var policy ACLPolicy
switch format { switch format {
case "yaml": case "yaml":
err := yaml.Unmarshal(acl, &policy) err := yaml.Unmarshal(acl, &policy)
if err != nil { if err != nil {
return err return nil, err
} }
default: default:
ast, err := hujson.Parse(acl) ast, err := hujson.Parse(acl)
if err != nil { if err != nil {
return err return nil, err
} }
ast.Standardize() ast.Standardize()
acl = ast.Pack() acl = ast.Pack()
err = json.Unmarshal(acl, &policy) err = json.Unmarshal(acl, &policy)
if err != nil { if err != nil {
return err return nil, err
} }
} }
if policy.IsZero() { if policy.IsZero() {
return errEmptyPolicy return nil, ErrEmptyPolicy
} }
h.aclPolicy = &policy return &policy, nil
return h.UpdateACLRules()
} }
func (h *Headscale) UpdateACLRules() error { // TODO(kradalby): This needs to be replace with something that generates
machines, err := h.db.ListMachines() // the rules as needed and not stores it on the global object, rules are
// per node and that should be taken into account.
func GenerateFilterRules(
policy *ACLPolicy,
machines types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
if policy == nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy
}
rules, err := policy.generateFilterRules(machines, stripEmailDomain)
if err != nil { if err != nil {
return err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
if h.aclPolicy == nil {
return errEmptyPolicy
}
rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain)
if err != nil {
return err
} }
log.Trace().Interface("ACL", rules).Msg("ACL rules generated") log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
h.aclRules = rules
var sshPolicy *tailcfg.SSHPolicy
if featureEnableSSH() { if featureEnableSSH() {
sshRules, err := h.generateSSHRules() sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
if err != nil { if err != nil {
return err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
if h.sshPolicy == nil { if sshPolicy == nil {
h.sshPolicy = &tailcfg.SSHPolicy{} sshPolicy = &tailcfg.SSHPolicy{}
} }
h.sshPolicy.Rules = sshRules sshPolicy.Rules = sshRules
} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 { } else if policy != nil && len(policy.SSHs) > 0 {
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
} }
return nil return rules, sshPolicy, nil
} }
// generateFilterRules takes a set of machines and an ACLPolicy and generates a // generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules( func (pol *ACLPolicy) generateFilterRules(
machines []Machine, machines types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
for index, acl := range pol.ACLs { for index, acl := range pol.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
return nil, errInvalidAction return nil, ErrInvalidAction
} }
srcIPs := []string{} srcIPs := []string{}
@ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules(
return rules, nil return rules, nil
} }
func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { func generateSSHRules(
policy *ACLPolicy,
machines types.Machines,
stripEmailDomain bool,
) ([]*tailcfg.SSHRule, error) {
rules := []*tailcfg.SSHRule{} rules := []*tailcfg.SSHRule{}
if h.aclPolicy == nil { if policy == nil {
return nil, errEmptyPolicy return nil, ErrEmptyPolicy
}
machines, err := h.db.ListMachines()
if err != nil {
return nil, err
} }
acceptAction := tailcfg.SSHAction{ acceptAction := tailcfg.SSHAction{
@ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
AllowLocalPortForwarding: false, AllowLocalPortForwarding: false,
} }
for index, sshACL := range h.aclPolicy.SSHs { for index, sshACL := range policy.SSHs {
action := rejectAction action := rejectAction
switch sshACL.Action { switch sshACL.Action {
case "accept": case "accept":
@ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
} }
default: default:
log.Error(). log.Error().
Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action)
return nil, err continue
} }
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
Any: true, Any: true,
}) })
} else if isGroup(rawSrc) { } else if isGroup(rawSrc) {
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain) users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex) Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
@ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
}) })
} }
} else { } else {
expandedSrcs, err := h.aclPolicy.expandAlias( expandedSrcs, err := policy.ExpandAlias(
machines, machines,
rawSrc, rawSrc,
h.cfg.OIDC.StripEmaildomain, stripEmailDomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
@ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
// with the given src alias. // with the given src alias.
func (pol *ACLPolicy) getIPsFromSource( func (pol *ACLPolicy) getIPsFromSource(
src string, src string,
machines []Machine, machines types.Machines,
stripEmaildomain bool, stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain) ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource(
// which are associated with the dest alias. // which are associated with the dest alias.
func (pol *ACLPolicy) getNetPortRangeFromDestination( func (pol *ACLPolicy) getNetPortRangeFromDestination(
dest string, dest string,
machines []Machine, machines types.Machines,
needsWildcard bool, needsWildcard bool,
stripEmaildomain bool, stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
@ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
return nil, fmt.Errorf( return nil, fmt.Errorf(
"failed to parse destination, tokens %v: %w", "failed to parse destination, tokens %v: %w",
tokens, tokens,
errInvalidPortFormat, ErrInvalidPortFormat,
) )
} else { } else {
tokens = []string{maybeIPv6Str, port} tokens = []string{maybeIPv6Str, port}
@ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
} }
expanded, err := pol.expandAlias( expanded, err := pol.ExpandAlias(
machines, machines,
alias, alias,
stripEmaildomain, stripEmaildomain,
@ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// - an ip // - an ip
// - a cidr // - a cidr
// and transform these in IPAddresses. // and transform these in IPAddresses.
func (pol *ACLPolicy) expandAlias( func (pol *ACLPolicy) ExpandAlias(
machines Machines, machines types.Machines,
alias string, alias string,
stripEmailDomain bool, stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
return parseIPSet("*", nil) return util.ParseIPSet("*", nil)
} }
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias(
// if alias is an host // if alias is an host
// Note, this is recursive. // Note, this is recursive.
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.expandAlias(machines, h.String(), stripEmailDomain) return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
} }
// if alias is an IP // if alias is an IP
@ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias(
// we assume in this function that we only have nodes from 1 user. // we assume in this function that we only have nodes from 1 user.
func excludeCorrectlyTaggedNodes( func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy, aclPolicy *ACLPolicy,
nodes []Machine, nodes types.Machines,
user string, user string,
stripEmailDomain bool, stripEmailDomain bool,
) []Machine { ) types.Machines {
out := []Machine{} out := types.Machines{}
tags := []string{} tags := []string{}
for tag := range aclPolicy.TagOwners { for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
@ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
} }
if needsWildcard { if needsWildcard {
return nil, errWildcardIsNeeded return nil, ErrWildcardIsNeeded
} }
ports := []tailcfg.PortRange{} ports := []tailcfg.PortRange{}
@ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
}) })
default: default:
return nil, errInvalidPortFormat return nil, ErrInvalidPortFormat
} }
} }
return &ports, nil return &ports, nil
} }
func filterMachinesByUser(machines []Machine, user string) []Machine { func filterMachinesByUser(machines types.Machines, user string) types.Machines {
out := []Machine{} out := types.Machines{}
for _, machine := range machines { for _, machine := range machines {
if machine.User.Name == user { if machine.User.Name == user {
out = append(out, machine) out = append(out, machine)
@ -664,7 +664,7 @@ func getTagOwners(
if !ok { if !ok {
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", "%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
errInvalidTag, ErrInvalidTag,
tag, tag,
) )
} }
@ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup(
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"group %v isn't registered. %w", "group %v isn't registered. %w",
group, group,
errInvalidGroup, ErrInvalidGroup,
) )
} }
for _, group := range aclGroups { for _, group := range aclGroups {
if isGroup(group) { if isGroup(group) {
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", "%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
errInvalidGroup, ErrInvalidGroup,
) )
} }
grp, err := NormalizeToFQDNRules(group, stripEmailDomain) grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
if err != nil { if err != nil {
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w", "failed to normalize group %q, err: %w",
group, group,
errInvalidGroup, ErrInvalidGroup,
) )
} }
users = append(users, grp) users = append(users, grp)
@ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup(
func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) getIPsFromGroup(
group string, group string,
machines Machines, machines types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsFromTag(
alias string, alias string,
machines Machines, machines types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag(
// find tag owners // find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain) owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil { if err != nil {
if errors.Is(err, errInvalidTag) { if errors.Is(err, ErrInvalidTag) {
ipSet, _ := build.IPSet() ipSet, _ := build.IPSet()
if len(ipSet.Prefixes()) == 0 { if len(ipSet.Prefixes()) == 0 {
return ipSet, fmt.Errorf( return ipSet, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined", "%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag, ErrInvalidTag,
alias, alias,
) )
} }
@ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag(
func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) getIPsForUser(
user string, user string,
machines Machines, machines types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser(
func (pol *ACLPolicy) getIPsFromSingleIP( func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr, ip netip.Addr,
machines Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip")
matches := machines.FilterByIP(ip) matches := machines.FilterByIP(ip)
@ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
func (pol *ACLPolicy) getIPsFromIPPrefix( func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix, prefix netip.Prefix,
machines Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -862,3 +862,65 @@ func isGroup(str string) bool {
func isTag(str string) bool { func isTag(str string) bool {
return strings.HasPrefix(str, "tag:") return strings.HasPrefix(str, "tag:")
} }
// getTags will return the tags of the current machine.
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
func (pol *ACLPolicy) GetTagsOfMachine(
machine types.Machine,
stripEmailDomain bool,
) ([]string, []string) {
validTags := make([]string, 0)
invalidTags := make([]string, 0)
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
for _, tag := range machine.HostInfo.RequestTags {
owners, err := getTagOwners(pol, tag, stripEmailDomain)
if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true
continue
}
var found bool
for _, owner := range owners {
if machine.User.Name == owner {
found = true
}
}
if found {
validTagMap[tag] = true
} else {
invalidTagMap[tag] = true
}
}
for tag := range invalidTagMap {
invalidTags = append(invalidTags, tag)
}
for tag := range validTagMap {
validTags = append(validTags, tag)
}
return validTags, invalidTags
}
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
func FilterMachinesByACL(
machine *types.Machine,
machines types.Machines,
filter []tailcfg.FilterRule,
) types.Machines {
result := types.Machines{}
for index, peer := range machines {
if peer.ID == machine.ID {
continue
}
if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) {
result = append(result, peer)
}
}
return result
}

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
package hscontrol package policy
import ( import (
"encoding/json" "encoding/json"

View file

@ -0,0 +1,61 @@
package matcher
import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
type Match struct {
Srcs *netipx.IPSet
Dests *netipx.IPSet
}
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
srcs := new(netipx.IPSetBuilder)
dests := new(netipx.IPSetBuilder)
for _, srcIP := range rule.SrcIPs {
set, _ := util.ParseIPSet(srcIP, nil)
srcs.AddSet(set)
}
for _, dest := range rule.DstPorts {
set, _ := util.ParseIPSet(dest.IP, nil)
dests.AddSet(set)
}
srcsSet, _ := srcs.IPSet()
destsSet, _ := dests.IPSet()
match := Match{
Srcs: srcsSet,
Dests: destsSet,
}
return match
}
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Srcs.Contains(ip) {
return true
}
}
return false
}
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Dests.Contains(ip) {
return true
}
}
return false
}

View file

@ -0,0 +1 @@
package matcher

View file

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon(
// that we rely on a method that calls back some how (OpenID or CLI) // that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback // We create the machine and then keep it around until a callback
// happens // happens
newMachine := Machine{ newMachine := types.Machine{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey), MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon(
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
) )
if err != nil || storedMachineKey.IsZero() { if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil {
if err := h.db.db.Save(&machine).Error; err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "RegistrationHandler"). Str("func", "RegistrationHandler").
@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon(
// If machine is not expired, and it is register, we have a already accepted this machine, // If machine is not expired, and it is register, we have a already accepted this machine,
// let it proceed with a valid registration // let it proceed with a valid registration
if !machine.isExpired() { if !machine.IsExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
return return
@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon(
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() { !machine.IsExpired() {
h.handleMachineRefreshKeyCommon( h.handleMachineRefreshKeyCommon(
writer, writer,
registerRequest, registerRequest,
@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon(
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
return return
@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon(
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
if pak != nil { if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
} else { } else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
} }
return return
@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon(
return return
} }
aclTags := pak.toProto().AclTags aclTags := pak.Proto().AclTags
if len(aclTags) > 0 { if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) err = h.db.SetTags(machine, aclTags)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon(
return return
} }
machineToRegister := Machine{ machineToRegister := types.Machine{
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
UserID: pak.User.ID, UserID: pak.User.ID,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey), MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry, Expiry: &registerRequest.Expiry,
NodeKey: nodeKey, NodeKey: nodeKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags, ForcedTags: pak.Proto().AclTags,
} }
machine, err = h.db.RegisterMachine( machine, err = h.db.RegisterMachine(
@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("could not register machine") Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to use pre-auth key") Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon(
} }
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.User.toTailscaleUser() resp.User = *pak.User.TailscaleUser()
// Provide LoginName when registering with pre-auth key // Provide LoginName when registering with pre-auth key
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
resp.Login = *pak.User.toTailscaleLogin() resp.Login = *pak.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon(
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
Inc() Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon(
func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, machine types.Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.NodeKeyExpired = true resp.NodeKeyExpired = true
resp.User = *machine.User.toTailscaleUser() resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon(
return return
} }
if machine.isEphemeral() { if machine.IsEphemeral() {
err = h.db.HardDeleteMachine(&machine) err = h.db.HardDeleteMachine(&machine)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon(
func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, machine types.Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *machine.User.toTailscaleUser() resp.User = *machine.User.TailscaleUser()
resp.Login = *machine.User.toTailscaleLogin() resp.Login = *machine.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine types.Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.db.Save(&machine).Error; err != nil { err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey)
if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
} }
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *machine.User.toTailscaleUser() resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, machine types.Machine,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName")
func (h *Headscale) handlePollCommon( func (h *Headscale) handlePollCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
ctx context.Context, ctx context.Context,
machine *Machine, machine *types.Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
isNoise bool, isNoise bool,
) { ) {
machine.Hostname = mapRequest.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*mapRequest.Hostinfo) machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
now := time.Now().UTC() now := time.Now().UTC()
err := h.db.processMachineRoutes(machine) err := h.db.ProcessMachineRoutes(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon(
} }
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil { if h.ACLPolicy != nil {
err := h.UpdateACLRules() // TODO(kradalby): Since this is not blocking, I might have introduced a bug here.
if err != nil { // It will be resolved later as we change up the policy stuff.
log.Error(). h.policyUpdateChan <- struct{}{}
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Err(err)
}
// update routes with peer information // update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -78,8 +74,7 @@ func (h *Headscale) handlePollCommon(
machine.LastSeen = &now machine.LastSeen = &now
} }
if err := h.db.db.Updates(machine).Error; err != nil { if err := h.db.MachineSave(machine); err != nil {
if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
@ -91,7 +86,6 @@ func (h *Headscale) handlePollCommon(
return return
} }
}
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
if err != nil { if err != nil {
@ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon(
func (h *Headscale) pollNetMapStream( func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter, writer http.ResponseWriter,
ctxReq context.Context, ctxReq context.Context,
machine *Machine, machine *types.Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
@ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream(
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
Inc() Inc()
if h.db.isOutdated(machine, h.getLastStateChange()) { if h.db.IsOutdated(machine, h.getLastStateChange()) {
var lastUpdate time.Time var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil { if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate lastUpdate = *machine.LastSuccessfulUpdate
@ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker(
updateChan chan struct{}, updateChan chan struct{},
keepAliveChan chan []byte, keepAliveChan chan []byte,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *types.Machine,
isNoise bool, isNoise bool,
) { ) {
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"sync" "sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -15,7 +16,7 @@ import (
func (h *Headscale) getMapResponseData( func (h *Headscale) getMapResponseData(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *types.Machine,
isNoise bool, isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
mapResponse, err := h.generateMapResponse(mapRequest, machine) mapResponse, err := h.generateMapResponse(mapRequest, machine)
@ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData(
func (h *Headscale) getMapKeepAliveResponseData( func (h *Headscale) getMapKeepAliveResponseData(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *types.Machine,
isNoise bool, isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{ keepAliveResponse := tailcfg.MapResponse{

View file

@ -18,7 +18,7 @@ type Suite struct{}
var ( var (
tmpDir string tmpDir string
app Headscale app *Headscale
) )
func (s *Suite) SetUpTest(c *check.C) { func (s *Suite) SetUpTest(c *check.C) {
@ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) {
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
} }
var err error var err error
tmpDir, err = os.MkdirTemp("", "autoygg-client-test") tmpDir, err = os.MkdirTemp("", "autoygg-client-test2")
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{ cfg := Config{
PrivateKeyPath: tmpDir + "/private.key",
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
DBtype: "sqlite3",
DBpath: tmpDir + "/headscale_test.db",
IPPrefixes: []netip.Prefix{ IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),
}, },
@ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) {
}, },
} }
// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift app, err = NewHeadscale(&cfg)
app = Headscale{
cfg: &cfg,
dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db",
stateUpdateChan: make(chan struct{}),
cancelStateUpdateChan: make(chan struct{}),
}
go app.watchStateChannel()
db, err := NewHeadscaleDatabase(
app.dbType,
app.dbString,
cfg.OIDC.StripEmaildomain,
false,
app.stateUpdateChan,
cfg.IPPrefixes,
"",
)
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
app.db = db
} }

View file

@ -0,0 +1,41 @@
package types
import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
)
// APIKey describes the datamodel for API keys used to remotely authenticate with
// headscale.
type APIKey struct {
ID uint64 `gorm:"primary_key"`
Prefix string `gorm:"uniqueIndex"`
Hash []byte
CreatedAt *time.Time
Expiration *time.Time
LastSeen *time.Time
}
func (key *APIKey) Proto() *v1.ApiKey {
protoKey := v1.ApiKey{
Id: key.ID,
Prefix: key.Prefix,
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
if key.LastSeen != nil {
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
}
return &protoKey
}

108
hscontrol/types/common.go Normal file
View file

@ -0,0 +1,108 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
"tailscale.com/tailcfg"
)
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
// This is a "wrapper" type around tailscales
// Hostinfo to allow us to add database "serialization"
// methods. This allows us to use a typed values throughout
// the code and not have to marshal/unmarshal and error
// check all over the code.
type HostInfo tailcfg.Hostinfo
func (hi *HostInfo) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, hi)
case string:
return json.Unmarshal([]byte(value), hi)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (hi HostInfo) Value() (driver.Value, error) {
bytes, err := json.Marshal(hi)
return string(bytes), err
}
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*i = IPPrefix(prefix)
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefix) Value() (driver.Value, error) {
prefixStr := netip.Prefix(i).String()
return prefixStr, nil
}
type IPPrefixes []netip.Prefix
func (i *IPPrefixes) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefixes) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}
type StringList []string
func (i *StringList) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i StringList) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}

254
hscontrol/types/machine.go Normal file
View file

@ -0,0 +1,254 @@
package types
import (
"database/sql/driver"
"errors"
"fmt"
"net/netip"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/tailcfg"
)
const (
// TODO(kradalby): Move out of here when we got circdeps under control.
keepAliveInterval = 60 * time.Second
)
var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
// Machine is a Headscale client.
type Machine struct {
ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"`
NodeKey string
DiscoKey string
IPAddresses MachineAddresses
// Hostname represents the name given by the Tailscale
// client during registration
Hostname string
// Givenname represents either:
// a DNS normalized version of Hostname
// a valid name set by the User
//
// GivenName is the name used in all DNS related
// parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"`
UserID uint
User User `gorm:"foreignKey:UserID"`
RegisterMethod string
ForcedTags StringList
// TODO(kradalby): This seems like irrelevant information?
AuthKeyID uint
AuthKey *PreAuthKey
LastSeen *time.Time
LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo HostInfo
Endpoints StringList
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type (
Machines []Machine
MachinesP []*Machine
)
type MachineAddresses []netip.Addr
func (ma MachineAddresses) ToStringSlice() []string {
strSlice := make([]string, 0, len(ma))
for _, addr := range ma {
strSlice = append(strSlice, addr.String())
}
return strSlice
}
// AppendToIPSet adds the individual ips in MachineAddresses to a
// given netipx.IPSetBuilder.
func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
for _, ip := range ma {
build.Add(ip)
}
}
func (ma *MachineAddresses) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
addresses := strings.Split(value, ",")
*ma = (*ma)[:0]
for _, addr := range addresses {
if len(addr) < 1 {
continue
}
parsed, err := netip.ParseAddr(addr)
if err != nil {
return err
}
*ma = append(*ma, parsed)
}
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (ma MachineAddresses) Value() (driver.Value, error) {
addresses := strings.Join(ma.ToStringSlice(), ",")
return addresses, nil
}
// IsExpired returns whether the machine registration has expired.
func (machine Machine) IsExpired() bool {
// If Expiry is not set, the client has not indicated that
// it wants an expiry time, it is therefor considered
// to mean "not expired"
if machine.Expiry == nil || machine.Expiry.IsZero() {
return false
}
return time.Now().UTC().After(*machine.Expiry)
}
// IsOnline returns if the machine is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
func (machine *Machine) IsOnline() bool {
if machine.LastSeen == nil {
return false
}
if machine.IsExpired() {
return false
}
return machine.LastSeen.After(time.Now().Add(-keepAliveInterval))
}
// IsEphemeral returns if the machine is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (machine *Machine) IsEphemeral() bool {
return machine.AuthKey != nil && machine.AuthKey.Ephemeral
}
func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool {
for _, rule := range filter {
// TODO(kradalby): Cache or pregen this
matcher := matcher.MatchFromFilterRule(rule)
if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) {
continue
}
if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) {
return true
}
}
return false
}
func (machines Machines) FilterByIP(ip netip.Addr) Machines {
found := make(Machines, 0)
for _, machine := range machines {
for _, mIP := range machine.IPAddresses {
if ip == mIP {
found = append(found, machine)
}
}
}
return found
}
func (machine *Machine) Proto() *v1.Machine {
machineProto := &v1.Machine{
Id: machine.ID,
MachineKey: machine.MachineKey,
NodeKey: machine.NodeKey,
DiscoKey: machine.DiscoKey,
IpAddresses: machine.IPAddresses.ToStringSlice(),
Name: machine.Hostname,
GivenName: machine.GivenName,
User: machine.User.Proto(),
ForcedTags: machine.ForcedTags,
Online: machine.IsOnline(),
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
CreatedAt: timestamppb.New(machine.CreatedAt),
}
if machine.AuthKey != nil {
machineProto.PreAuthKey = machine.AuthKey.Proto()
}
if machine.LastSeen != nil {
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
}
if machine.LastSuccessfulUpdate != nil {
machineProto.LastSuccessfulUpdate = timestamppb.New(
*machine.LastSuccessfulUpdate,
)
}
if machine.Expiry != nil {
machineProto.Expiry = timestamppb.New(*machine.Expiry)
}
return machineProto
}
// GetHostInfo returns a Hostinfo struct for the machine.
func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
return tailcfg.Hostinfo(machine.HostInfo)
}
func (machine Machine) String() string {
return machine.Hostname
}
func (machines Machines) String() string {
temp := make([]string, len(machines))
for index, machine := range machines {
temp[index] = machine.Hostname
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
// TODO(kradalby): Remove when we have generics...
func (machines MachinesP) String() string {
temp := make([]string, len(machines))
for index, machine := range machines {
temp[index] = machine.Hostname
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}

View file

@ -0,0 +1 @@
package types

View file

@ -0,0 +1,58 @@
package types
import (
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"google.golang.org/protobuf/types/known/timestamppb"
)
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
Key string
UserID uint
User User
Reusable bool
Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
ACLTags []PreAuthKeyACLTag
CreatedAt *time.Time
Expiration *time.Time
}
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
type PreAuthKeyACLTag struct {
ID uint64 `gorm:"primary_key"`
PreAuthKeyID uint64
Tag string
}
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
User: key.User.Name,
Id: strconv.FormatUint(key.ID, util.Base10),
Key: key.Key,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
Used: key.Used,
AclTags: make([]string, len(key.ACLTags)),
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
for idx := range key.ACLTags {
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
}
return &protoKey
}

71
hscontrol/types/routes.go Normal file
View file

@ -0,0 +1,71 @@
package types
import (
"fmt"
"net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
)
var (
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
)
type Route struct {
gorm.Model
MachineID uint64
Machine Machine
Prefix IPPrefix
Advertised bool
Enabled bool
IsPrimary bool
}
type Routes []Route
func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
}
func (r *Route) IsExitRoute() bool {
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
}
func (rs Routes) Prefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs {
prefixes[i] = netip.Prefix(r.Prefix)
}
return prefixes
}
func (rs Routes) Proto() []*v1.Route {
protoRoutes := []*v1.Route{}
for _, route := range rs {
protoRoute := v1.Route{
Id: uint64(route.ID),
Machine: route.Machine.Proto(),
Prefix: netip.Prefix(route.Prefix).String(),
Advertised: route.Advertised,
Enabled: route.Enabled,
IsPrimary: route.IsPrimary,
CreatedAt: timestamppb.New(route.CreatedAt),
UpdatedAt: timestamppb.New(route.UpdatedAt),
}
if route.DeletedAt.Valid {
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
}
protoRoutes = append(protoRoutes, &protoRoute)
}
return protoRoutes
}

55
hscontrol/types/users.go Normal file
View file

@ -0,0 +1,55 @@
package types
import (
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
// User is the way Headscale implements the concept of users in Tailscale
//
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
// that contain our machines.
type User struct {
gorm.Model
Name string `gorm:"unique"`
}
func (n *User) TailscaleUser() *tailcfg.User {
user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
Logins: []tailcfg.LoginID{},
Created: time.Time{},
}
return &user
}
func (n *User) TailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
return &login
}
func (n *User) Proto() *v1.User {
return &v1.User{
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}
}

View file

@ -1,415 +0,0 @@
package hscontrol
import (
"net/netip"
"testing"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
user, err := app.db.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(user.Name, check.Equals, "test")
users, err := app.db.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = app.db.DestroyUser("test")
c.Assert(err, check.IsNil)
_, err = app.db.GetUser("test")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := app.db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
user, err := app.db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
err = app.db.DestroyUser("test")
c.Assert(err, check.IsNil)
result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
user, err = app.db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.db.Save(&machine)
err = app.db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes)
}
func (s *Suite) TestRenameUser(c *check.C) {
userTest, err := app.db.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(userTest.Name, check.Equals, "test")
users, err := app.db.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = app.db.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil)
_, err = app.db.GetUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
_, err = app.db.GetUser("test-renamed")
c.Assert(err, check.IsNil)
err = app.db.RenameUser("test-does-not-exit", "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := app.db.CreateUser("test2")
c.Assert(err, check.IsNil)
c.Assert(userTest2.Name, check.Equals, "test2")
err = app.db.RenameUser("test2", "test-renamed")
c.Assert(err, check.Equals, ErrUserExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
userShared1, err := app.db.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.db.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.db.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyShared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared2, err := app.db.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared3, err := app.db.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2Shared1, err := app.db.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID),
}
app.db.db.Save(machineInShared1)
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID),
}
app.db.db.Save(machineInShared2)
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID),
}
app.db.db.Save(machineInShared3)
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
app.db.db.Save(machine2InShared1)
peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
c.Assert(err, check.IsNil)
userProfiles := app.db.getMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
c.Assert(len(userProfiles), check.Equals, 3)
found := false
for _, userProfiles := range userProfiles {
if userProfiles.DisplayName == userShared1.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == userShared2.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func TestNormalizeToFQDNRules(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "user name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "user with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeToFQDNRules() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckForFQDNRules(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid: user",
args: args{name: "valid-user"},
wantErr: false,
},
{
name: "invalid: capitalized user",
args: args{name: "Invalid-CapItaLIzed-user"},
wantErr: true,
},
{
name: "invalid: email as user",
args: args{name: "foo.bar@example.com"},
wantErr: true,
},
{
name: "invalid: chars in user name",
args: args{name: "super-user+name"},
wantErr: true,
},
{
name: "invalid: too long name for user",
args: args{
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := app.db.CreateUser("old")
c.Assert(err, check.IsNil)
newUser, err := app.db.CreateUser("new")
c.Assert(err, check.IsNil)
pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: oldUser.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.db.Save(&machine)
c.Assert(machine.UserID, check.Equals, oldUser.ID)
err = app.db.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
err = app.db.SetMachineUser(&machine, "non-existing-user")
c.Assert(err, check.Equals, ErrUserNotFound)
err = app.db.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
}

View file

@ -1,12 +1,94 @@
package util package util
import ( import (
"fmt"
"net/netip" "net/netip"
"reflect" "reflect"
"strings"
"go4.org/netipx" "go4.org/netipx"
) )
// This is borrowed from, and updated to use IPSet
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
// TODO(kradalby): contribute upstream and make public.
var (
zeroIP4 = netip.AddrFrom4([4]byte{})
zeroIP6 = netip.AddrFrom16([16]byte{})
)
// parseIPSet parses arg as one:
//
// - an IP address (IPv4 or IPv6)
// - the string "*" to match everything (both IPv4 & IPv6)
// - a CIDR (e.g. "192.168.0.0/16")
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
//
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
// address (without a slash) treated as a CIDR of *bits length.
// nolint
func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
var ipSet netipx.IPSetBuilder
if arg == "*" {
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
return ipSet.IPSet()
}
if strings.Contains(arg, "/") {
pfx, err := netip.ParsePrefix(arg)
if err != nil {
return nil, err
}
if pfx != pfx.Masked() {
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
}
ipSet.AddPrefix(pfx)
return ipSet.IPSet()
}
if strings.Count(arg, "-") == 1 {
ip1s, ip2s, _ := strings.Cut(arg, "-")
ip1, err := netip.ParseAddr(ip1s)
if err != nil {
return nil, err
}
ip2, err := netip.ParseAddr(ip2s)
if err != nil {
return nil, err
}
r := netipx.IPRangeFrom(ip1, ip2)
if !r.IsValid() {
return nil, fmt.Errorf("invalid IP range %q", arg)
}
for _, prefix := range r.Prefixes() {
ipSet.AddPrefix(prefix)
}
return ipSet.IPSet()
}
ip, err := netip.ParseAddr(arg)
if err != nil {
return nil, fmt.Errorf("invalid IP address %q", arg)
}
bits8 := uint8(ip.BitLen())
if bits != nil {
if *bits < 0 || *bits > int(bits8) {
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
}
bits8 = uint8(*bits)
}
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
return ipSet.IPSet()
}
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
var network, broadcast netip.Addr var network, broadcast netip.Addr
ipRange := netipx.RangeOfPrefix(na) ipRange := netipx.RangeOfPrefix(na)

View file

@ -1,4 +1,4 @@
package hscontrol package util
import ( import (
"net/netip" "net/netip"
@ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := parseIPSet(tt.args.arg, tt.args.bits) got, err := ParseIPSet(tt.args.arg, tt.args.bits)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr)

7
hscontrol/util/const.go Normal file
View file

@ -0,0 +1,7 @@
package util
const (
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
)

69
hscontrol/util/dns.go Normal file
View file

@ -0,0 +1,69 @@
package util
import (
"errors"
"fmt"
"regexp"
"strings"
)
const (
// value related to RFC 1123 and 952.
LabelHostnameLength = 63
)
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidUserName = errors.New("invalid user name")
// NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > LabelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidUserName,
)
}
}
return name, nil
}
func CheckForFQDNRules(name string) error {
if len(name) > LabelHostnameLength {
return fmt.Errorf(
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if invalidCharsInUserRegex.MatchString(name) {
return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name,
ErrInvalidUserName,
)
}
return nil
}

143
hscontrol/util/dns_test.go Normal file
View file

@ -0,0 +1,143 @@
package util
import "testing"
func TestNormalizeToFQDNRules(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "user name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "user with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeToFQDNRules() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckForFQDNRules(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid: user",
args: args{name: "valid-user"},
wantErr: false,
},
{
name: "invalid: capitalized user",
args: args{name: "Invalid-CapItaLIzed-user"},
wantErr: true,
},
{
name: "invalid: email as user",
args: args{name: "foo.bar@example.com"},
wantErr: true,
},
{
name: "invalid: chars in user name",
args: args{name: "super-user+name"},
wantErr: true,
},
{
name: "invalid: too long name for user",
args: args{
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -6,7 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -45,7 +45,7 @@ var veryLargeDestination = []string{
"208.0.0.0/4:*", "208.0.0.0/4:*",
} }
func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario { func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario() scenario, err := NewScenario()
assert.NoError(t, err) assert.NoError(t, err)
@ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
// they can access minus one (them self). // they can access minus one (them self).
tests := map[string]struct { tests := map[string]struct {
users map[string]int users map[string]int
policy hscontrol.ACLPolicy policy policy.ACLPolicy
want map[string]int want map[string]int
}{ }{
// Test that when we have no ACL, each client netmap has // Test that when we have no ACL, each client netmap has
@ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
"user1": 2, "user1": 2,
"user2": 2, "user2": 2,
}, },
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
@ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
"user1": 2, "user1": 2,
"user2": 2, "user2": 2,
}, },
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
"user1": 2, "user1": 2,
"user2": 2, "user2": 2,
}, },
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
"user1": 2, "user1": 2,
"user2": 2, "user2": 2,
}, },
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
"user1": 2, "user1": 2,
"user2": 2, "user2": 2,
}, },
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
scenario := aclScenario(t, scenario := aclScenario(t,
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
scenario := aclScenario(t, scenario := aclScenario(t,
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:integration-acl-test": {"user1", "user2"}, "group:integration-acl-test": {"user1", "user2"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:integration-acl-test"}, Sources: []string{"group:integration-acl-test"},
@ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
scenario := aclScenario(t, scenario := aclScenario(t,
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
scenario := aclScenario(t, scenario := aclScenario(t,
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"user1"}, Sources: []string{"user1"},
@ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
scenario := aclScenario(t, scenario := aclScenario(t,
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Hosts: hscontrol.Hosts{ Hosts: policy.Hosts{
"all": netip.MustParsePrefix("100.64.0.0/24"), "all": netip.MustParsePrefix("100.64.0.0/24"),
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
// Everyone can curl test3 // Everyone can curl test3
{ {
Action: "accept", Action: "accept",
@ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
tests := map[string]struct { tests := map[string]struct {
policy hscontrol.ACLPolicy policy policy.ACLPolicy
}{ }{
"ipv4": { "ipv4": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
Hosts: hscontrol.Hosts{ Hosts: policy.Hosts{
"test1": netip.MustParsePrefix("100.64.0.1/32"), "test1": netip.MustParsePrefix("100.64.0.1/32"),
"test2": netip.MustParsePrefix("100.64.0.2/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"),
"test3": netip.MustParsePrefix("100.64.0.3/32"), "test3": netip.MustParsePrefix("100.64.0.3/32"),
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
// Everyone can curl test3 // Everyone can curl test3
{ {
Action: "accept", Action: "accept",
@ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) {
}, },
}, },
"ipv6": { "ipv6": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
Hosts: hscontrol.Hosts{ Hosts: policy.Hosts{
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"),
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
// Everyone can curl test3 // Everyone can curl test3
{ {
Action: "accept", Action: "accept",
@ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
tests := map[string]struct { tests := map[string]struct {
policy hscontrol.ACLPolicy policy policy.ACLPolicy
}{ }{
"ipv4": { "ipv4": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"100.64.0.1"}, Sources: []string{"100.64.0.1"},
@ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
}, },
}, },
"ipv6": { "ipv6": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"fd7a:115c:a1e0::1"}, Sources: []string{"fd7a:115c:a1e0::1"},
@ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
}, },
}, },
"hostv4cidr": { "hostv4cidr": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
Hosts: hscontrol.Hosts{ Hosts: policy.Hosts{
"test1": netip.MustParsePrefix("100.64.0.1/32"), "test1": netip.MustParsePrefix("100.64.0.1/32"),
"test2": netip.MustParsePrefix("100.64.0.2/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"),
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"test1"}, Sources: []string{"test1"},
@ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
}, },
}, },
"hostv6cidr": { "hostv6cidr": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
Hosts: hscontrol.Hosts{ Hosts: policy.Hosts{
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"test1"}, Sources: []string{"test1"},
@ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
}, },
}, },
"group": { "group": {
policy: hscontrol.ACLPolicy{ policy: policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:one": {"user1"}, "group:one": {"user1"},
"group:two": {"user2"}, "group:two": {"user2"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:one"}, Sources: []string{"group:one"},

View file

@ -23,7 +23,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/integrationutil"
@ -60,7 +60,7 @@ type HeadscaleInContainer struct {
port int port int
extraPorts []string extraPorts []string
hostPortBindings map[string][]string hostPortBindings map[string][]string
aclPolicy *hscontrol.ACLPolicy aclPolicy *policy.ACLPolicy
env map[string]string env map[string]string
tlsCert []byte tlsCert []byte
tlsKey []byte tlsKey []byte
@ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer)
// WithACLPolicy adds a hscontrol.ACLPolicy policy to the // WithACLPolicy adds a hscontrol.ACLPolicy policy to the
// HeadscaleInContainer instance. // HeadscaleInContainer instance.
func WithACLPolicy(acl *hscontrol.ACLPolicy) Option { func WithACLPolicy(acl *policy.ACLPolicy) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
// TODO(kradalby): Move somewhere appropriate // TODO(kradalby): Move somewhere appropriate
hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) {
err = scenario.CreateHeadscaleEnv(spec, err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithSSH()}, []tsic.Option{tsic.WithSSH()},
hsic.WithACLPolicy( hsic.WithACLPolicy(
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:integration-test": {"user1"}, "group:integration-test": {"user1"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
Destinations: []string{"*:*"}, Destinations: []string{"*:*"},
}, },
}, },
SSHs: []hscontrol.SSH{ SSHs: []policy.SSH{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:integration-test"}, Sources: []string{"group:integration-test"},
@ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
err = scenario.CreateHeadscaleEnv(spec, err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithSSH()}, []tsic.Option{tsic.WithSSH()},
hsic.WithACLPolicy( hsic.WithACLPolicy(
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:integration-test": {"user1", "user2"}, "group:integration-test": {"user1", "user2"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
Destinations: []string{"*:*"}, Destinations: []string{"*:*"},
}, },
}, },
SSHs: []hscontrol.SSH{ SSHs: []policy.SSH{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:integration-test"}, Sources: []string{"group:integration-test"},
@ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) {
err = scenario.CreateHeadscaleEnv(spec, err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithSSH()}, []tsic.Option{tsic.WithSSH()},
hsic.WithACLPolicy( hsic.WithACLPolicy(
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:integration-test": {"user1"}, "group:integration-test": {"user1"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
Destinations: []string{"*:*"}, Destinations: []string{"*:*"},
}, },
}, },
SSHs: []hscontrol.SSH{}, SSHs: []policy.SSH{},
}, },
), ),
hsic.WithTestName("sshnoneconfigured"), hsic.WithTestName("sshnoneconfigured"),
@ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) {
err = scenario.CreateHeadscaleEnv(spec, err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithSSH()}, []tsic.Option{tsic.WithSSH()},
hsic.WithACLPolicy( hsic.WithACLPolicy(
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:integration-test": {"user1"}, "group:integration-test": {"user1"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
Destinations: []string{"*:80"}, Destinations: []string{"*:80"},
}, },
}, },
SSHs: []hscontrol.SSH{ SSHs: []policy.SSH{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:integration-test"}, Sources: []string{"group:integration-test"},
@ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) {
err = scenario.CreateHeadscaleEnv(spec, err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithSSH()}, []tsic.Option{tsic.WithSSH()},
hsic.WithACLPolicy( hsic.WithACLPolicy(
&hscontrol.ACLPolicy{ &policy.ACLPolicy{
Groups: map[string][]string{ Groups: map[string][]string{
"group:ssh1": {"useracl1"}, "group:ssh1": {"useracl1"},
"group:ssh2": {"useracl2"}, "group:ssh2": {"useracl2"},
}, },
ACLs: []hscontrol.ACL{ ACLs: []policy.ACL{
{ {
Action: "accept", Action: "accept",
Sources: []string{"*"}, Sources: []string{"*"},
Destinations: []string{"*:*"}, Destinations: []string{"*:*"},
}, },
}, },
SSHs: []hscontrol.SSH{ SSHs: []policy.SSH{
{ {
Action: "accept", Action: "accept",
Sources: []string{"group:ssh1"}, Sources: []string{"group:ssh1"},