mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
Split code into modules
This is a massive commit that restructures the code into modules: db/ All functions related to modifying the Database types/ All type definitions and methods that can be exclusivly used on these types without dependencies policy/ All Policy related code, now without dependencies on the Database. policy/matcher/ Dedicated code to match machines in a list of FilterRules Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
14e29a7bee
commit
feb15365b5
51 changed files with 4677 additions and 4290 deletions
|
@ -7,7 +7,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData {
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 {
|
if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 {
|
||||||
isPrimaryStr = "-"
|
isPrimaryStr = "-"
|
||||||
} else {
|
} else {
|
||||||
isPrimaryStr = strconv.FormatBool(route.IsPrimary)
|
isPrimaryStr = strconv.FormatBool(route.IsPrimary)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
@ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
|
||||||
|
|
||||||
if cfg.ACL.PolicyPath != "" {
|
if cfg.ACL.PolicyPath != "" {
|
||||||
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
|
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
|
||||||
err = app.LoadACLPolicyFromPath(aclPath)
|
pol, err := policy.LoadACLPolicyFromPath(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not load the ACL policy")
|
Msg("Could not load the ACL policy")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app.ACLPolicy = pol
|
||||||
}
|
}
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
|
|
|
@ -18,9 +18,6 @@ const (
|
||||||
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
|
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
|
||||||
registrationHoldoff = time.Second * 5
|
registrationHoldoff = time.Second * 5
|
||||||
reservedResponseHeaderSize = 4
|
reservedResponseHeaderSize = 4
|
||||||
RegisterMethodAuthKey = "authkey"
|
|
||||||
RegisterMethodOIDC = "oidc"
|
|
||||||
RegisterMethodCLI = "cli"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||||
|
@ -56,7 +53,7 @@ func (h *Headscale) HealthHandler(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.pingDB(req.Context()); err != nil {
|
if err := h.db.PingDB(req.Context()); err != nil {
|
||||||
respond(err)
|
respond(err)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,6 +3,7 @@ package hscontrol
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -10,13 +11,13 @@ import (
|
||||||
|
|
||||||
func (h *Headscale) generateMapResponse(
|
func (h *Headscale) generateMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("func", "generateMapResponse").
|
Str("func", "generateMapResponse").
|
||||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||||
Msg("Creating Map response")
|
Msg("Creating Map response")
|
||||||
node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine)
|
peers, err := h.db.GetValidPeers(h.aclRules, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := h.db.getMapResponseUserProfiles(*machine, peers)
|
profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
|
||||||
|
|
||||||
nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
166
hscontrol/app.go
166
hscontrol/app.go
|
@ -23,6 +23,9 @@ import (
|
||||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||||
"github.com/juanfont/headscale"
|
"github.com/juanfont/headscale"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
|
@ -73,7 +76,7 @@ const (
|
||||||
// Headscale represents the base app of the service.
|
// Headscale represents the base app of the service.
|
||||||
type Headscale struct {
|
type Headscale struct {
|
||||||
cfg *Config
|
cfg *Config
|
||||||
db *HSDatabase
|
db *db.HSDatabase
|
||||||
dbString string
|
dbString string
|
||||||
dbType string
|
dbType string
|
||||||
dbDebug bool
|
dbDebug bool
|
||||||
|
@ -83,7 +86,7 @@ type Headscale struct {
|
||||||
DERPMap *tailcfg.DERPMap
|
DERPMap *tailcfg.DERPMap
|
||||||
DERPServer *DERPServer
|
DERPServer *DERPServer
|
||||||
|
|
||||||
aclPolicy *ACLPolicy
|
ACLPolicy *policy.ACLPolicy
|
||||||
aclRules []tailcfg.FilterRule
|
aclRules []tailcfg.FilterRule
|
||||||
sshPolicy *tailcfg.SSHPolicy
|
sshPolicy *tailcfg.SSHPolicy
|
||||||
|
|
||||||
|
@ -99,6 +102,12 @@ type Headscale struct {
|
||||||
|
|
||||||
stateUpdateChan chan struct{}
|
stateUpdateChan chan struct{}
|
||||||
cancelStateUpdateChan chan struct{}
|
cancelStateUpdateChan chan struct{}
|
||||||
|
|
||||||
|
// TODO(kradalby): Temporary measure to make sure we can update policy
|
||||||
|
// across modules, will be removed when aclRules are no longer stored
|
||||||
|
// globally but generated per node basis.
|
||||||
|
policyUpdateChan chan struct{}
|
||||||
|
cancelPolicyUpdateChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
|
@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
|
|
||||||
var dbString string
|
var dbString string
|
||||||
switch cfg.DBtype {
|
switch cfg.DBtype {
|
||||||
case Postgres:
|
case db.Postgres:
|
||||||
dbString = fmt.Sprintf(
|
dbString = fmt.Sprintf(
|
||||||
"host=%s dbname=%s user=%s",
|
"host=%s dbname=%s user=%s",
|
||||||
cfg.DBhost,
|
cfg.DBhost,
|
||||||
|
@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
if cfg.DBpass != "" {
|
if cfg.DBpass != "" {
|
||||||
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
|
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
|
||||||
}
|
}
|
||||||
case Sqlite:
|
case db.Sqlite:
|
||||||
dbString = cfg.DBpath
|
dbString = cfg.DBpath
|
||||||
default:
|
default:
|
||||||
return nil, errUnsupportedDatabase
|
return nil, errUnsupportedDatabase
|
||||||
|
@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
|
|
||||||
stateUpdateChan: make(chan struct{}),
|
stateUpdateChan: make(chan struct{}),
|
||||||
cancelStateUpdateChan: make(chan struct{}),
|
cancelStateUpdateChan: make(chan struct{}),
|
||||||
|
|
||||||
|
policyUpdateChan: make(chan struct{}),
|
||||||
|
cancelPolicyUpdateChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
go app.watchStateChannel()
|
go app.watchStateChannel()
|
||||||
|
go app.watchPolicyChannel()
|
||||||
|
|
||||||
db, err := NewHeadscaleDatabase(
|
database, err := db.NewHeadscaleDatabase(
|
||||||
cfg.DBtype,
|
cfg.DBtype,
|
||||||
dbString,
|
dbString,
|
||||||
cfg.OIDC.StripEmaildomain,
|
cfg.OIDC.StripEmaildomain,
|
||||||
app.dbDebug,
|
app.dbDebug,
|
||||||
app.stateUpdateChan,
|
app.stateUpdateChan,
|
||||||
|
app.policyUpdateChan,
|
||||||
cfg.IPPrefixes,
|
cfg.IPPrefixes,
|
||||||
cfg.BaseDomain)
|
cfg.BaseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db = db
|
app.db = database
|
||||||
|
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = app.initOIDC()
|
err = app.initOIDC()
|
||||||
|
@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
||||||
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
h.expireEphemeralNodesWorker()
|
h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||||
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
h.expireExpiredMachinesWorker()
|
h.db.ExpireExpiredMachines(h.getLastStateChange())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
err := h.db.handlePrimarySubnetFailover()
|
err := h.db.HandlePrimarySubnetFailover()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expireEphemeralNodesWorker() {
|
|
||||||
users, err := h.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range users {
|
|
||||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("user", user.Name).
|
|
||||||
Msg("Error listing machines in user")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expiredFound := false
|
|
||||||
for _, machine := range machines {
|
|
||||||
if machine.isEphemeral() && machine.LastSeen != nil &&
|
|
||||||
time.Now().
|
|
||||||
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
|
||||||
expiredFound = true
|
|
||||||
log.Info().
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Ephemeral client removed from database")
|
|
||||||
|
|
||||||
err = h.db.db.Unscoped().Delete(machine).Error
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if expiredFound {
|
|
||||||
h.setLastStateChangeToNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) expireExpiredMachinesWorker() {
|
|
||||||
users, err := h.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range users {
|
|
||||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("user", user.Name).
|
|
||||||
Msg("Error listing machines in user")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expiredFound := false
|
|
||||||
for index, machine := range machines {
|
|
||||||
if machine.isExpired() &&
|
|
||||||
machine.Expiry.After(h.getLastStateChange(user)) {
|
|
||||||
expiredFound = true
|
|
||||||
|
|
||||||
err := h.db.ExpireMachine(&machines[index])
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("name", machine.GivenName).
|
|
||||||
Msg("🤮 Cannot expire machine")
|
|
||||||
} else {
|
|
||||||
log.Info().
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("name", machine.GivenName).
|
|
||||||
Msg("Machine successfully expired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if expiredFound {
|
|
||||||
h.setLastStateChangeToNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
req interface{},
|
req interface{},
|
||||||
info *grpc.UnaryServerInfo,
|
info *grpc.UnaryServerInfo,
|
||||||
|
@ -565,6 +487,8 @@ func (h *Headscale) Serve() error {
|
||||||
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
|
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): These should have cancel channels and be cleaned
|
||||||
|
// up on shutdown.
|
||||||
go h.expireEphemeralNodes(updateInterval)
|
go h.expireEphemeralNodes(updateInterval)
|
||||||
go h.expireExpiredMachines(updateInterval)
|
go h.expireExpiredMachines(updateInterval)
|
||||||
|
|
||||||
|
@ -774,10 +698,12 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
if h.cfg.ACL.PolicyPath != "" {
|
if h.cfg.ACL.PolicyPath != "" {
|
||||||
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
||||||
err := h.LoadACLPolicyFromPath(aclPath)
|
pol, err := policy.LoadACLPolicyFromPath(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.ACLPolicy = pol
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||||
|
@ -824,12 +750,12 @@ func (h *Headscale) Serve() error {
|
||||||
close(h.stateUpdateChan)
|
close(h.stateUpdateChan)
|
||||||
close(h.cancelStateUpdateChan)
|
close(h.cancelStateUpdateChan)
|
||||||
|
|
||||||
|
<-h.cancelPolicyUpdateChan
|
||||||
|
close(h.policyUpdateChan)
|
||||||
|
close(h.cancelPolicyUpdateChan)
|
||||||
|
|
||||||
// Close db connections
|
// Close db connections
|
||||||
db, err := h.db.db.DB()
|
err = h.db.Close()
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to get db handle")
|
|
||||||
}
|
|
||||||
err = db.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to close db")
|
log.Error().Err(err).Msg("Failed to close db")
|
||||||
}
|
}
|
||||||
|
@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): baby steps, make this more robust.
|
||||||
|
func (h *Headscale) watchPolicyChannel() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.policyUpdateChan:
|
||||||
|
machines, err := h.db.ListMachines()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to fetch machines during policy update")
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to update ACL rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.aclRules = rules
|
||||||
|
h.sshPolicy = sshPolicy
|
||||||
|
|
||||||
|
case <-h.cancelPolicyUpdateChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Headscale) setLastStateChangeToNow() {
|
func (h *Headscale) setLastStateChangeToNow() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getLastStateChange(users ...User) time.Time {
|
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
|
||||||
times := []time.Time{}
|
times := []time.Time{}
|
||||||
|
|
||||||
// getLastStateChange takes a list of users as a "filter", if no users
|
// getLastStateChange takes a list of users as a "filter", if no users
|
||||||
|
|
480
hscontrol/db/acls_test.go
Normal file
480
hscontrol/db/acls_test.go
Normal file
|
@ -0,0 +1,480 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"tailscale.com/envknob"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(kradalby):
|
||||||
|
// Convert these tests to being non-database dependent and table driven. They are
|
||||||
|
// very verbose, and dont really need the database.
|
||||||
|
|
||||||
|
func (s *Suite) TestSshRules(c *check.C) {
|
||||||
|
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
|
||||||
|
|
||||||
|
user, err := db.CreateUser("user1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
hostInfo := tailcfg.Hostinfo{
|
||||||
|
OS: "centos",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
RequestTags: []string{"tag:test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
aclPolicy := &policy.ACLPolicy{
|
||||||
|
Groups: policy.Groups{
|
||||||
|
"group:test": []string{"user1"},
|
||||||
|
},
|
||||||
|
Hosts: policy.Hosts{
|
||||||
|
"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
|
||||||
|
},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"*:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SSHs: []policy.SSH{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"group:test"},
|
||||||
|
Destinations: []string{"client"},
|
||||||
|
Users: []string{"autogroup:nonroot"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"client"},
|
||||||
|
Users: []string{"autogroup:nonroot"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
|
||||||
|
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(sshPolicy, check.NotNil)
|
||||||
|
c.Assert(sshPolicy.Rules, check.HasLen, 2)
|
||||||
|
c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
|
||||||
|
c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1)
|
||||||
|
c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")
|
||||||
|
|
||||||
|
c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
|
||||||
|
c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1)
|
||||||
|
c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// this test should validate that we can expand a group in a TagOWner section and
|
||||||
|
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||||
|
// the tag is matched in the Sources section.
|
||||||
|
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||||
|
user, err := db.CreateUser("user1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
hostInfo := tailcfg.Hostinfo{
|
||||||
|
OS: "centos",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
RequestTags: []string{"tag:test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pol := &policy.ACLPolicy{
|
||||||
|
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||||
|
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"tag:test"},
|
||||||
|
Destinations: []string{"*:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||||
|
}
|
||||||
|
|
||||||
|
// this test should validate that we can expand a group in a TagOWner section and
|
||||||
|
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||||
|
// the tag is matched in the Destinations section.
|
||||||
|
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||||
|
user, err := db.CreateUser("user1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
hostInfo := tailcfg.Hostinfo{
|
||||||
|
OS: "centos",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
RequestTags: []string{"tag:test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "12345",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pol := &policy.ACLPolicy{
|
||||||
|
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||||
|
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"tag:test:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||||
|
}
|
||||||
|
|
||||||
|
// need a test with:
|
||||||
|
// tag on a host that isn't owned by a tag owners. So the user
|
||||||
|
// of the host should be valid.
|
||||||
|
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||||
|
user, err := db.CreateUser("user1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
hostInfo := tailcfg.Hostinfo{
|
||||||
|
OS: "centos",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
RequestTags: []string{"tag:foo"},
|
||||||
|
}
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "12345",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pol := &policy.ACLPolicy{
|
||||||
|
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"user1"},
|
||||||
|
Destinations: []string{"*:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||||
|
}
|
||||||
|
|
||||||
|
// tag on a host is owned by a tag owner, the tag is valid.
|
||||||
|
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||||
|
// host should be tied to the tag now.
|
||||||
|
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||||
|
user, err := db.CreateUser("user1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "webserver")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
hostInfo := tailcfg.Hostinfo{
|
||||||
|
OS: "centos",
|
||||||
|
Hostname: "webserver",
|
||||||
|
RequestTags: []string{"tag:webapp"},
|
||||||
|
}
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "12345",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "webserver",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user1", "user")
|
||||||
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
|
OS: "debian",
|
||||||
|
Hostname: "Hostname",
|
||||||
|
}
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
machine = types.Machine{
|
||||||
|
ID: 2,
|
||||||
|
MachineKey: "56789",
|
||||||
|
NodeKey: "bar2",
|
||||||
|
DiscoKey: "faab",
|
||||||
|
Hostname: "user",
|
||||||
|
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo(hostInfo2),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pol := &policy.ACLPolicy{
|
||||||
|
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"user1"},
|
||||||
|
Destinations: []string{"tag:webapp:80,443"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
|
||||||
|
c.Assert(rules[0].DstPorts, check.HasLen, 2)
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||||
|
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||||
|
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||||
|
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||||
|
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestPortUser(c *check.C) {
|
||||||
|
user, err := db.CreateUser("testuser")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("testuser", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
ips, _ := db.getAvailableIPs()
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "12345",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: ips,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
acl := []byte(`
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"testuser",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||||
|
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||||
|
c.Assert(len(ips), check.Equals, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
|
user, err := db.CreateUser("testuser")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("testuser", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
ips, _ := db.getAvailableIPs()
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: ips,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
err = db.MachineSave(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
acl := []byte(`
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:example": [
|
||||||
|
"testuser",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"group:example",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(rules, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||||
|
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||||
|
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||||
|
c.Assert(len(ips), check.Equals, 1)
|
||||||
|
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||||
|
}
|
|
@ -3,21 +3,22 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) {
|
func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) {
|
||||||
var ips MachineAddresses
|
var ips types.MachineAddresses
|
||||||
var err error
|
var err error
|
||||||
for _, ipPrefix := range hsdb.ipPrefixes {
|
for _, ipPrefix := range hsdb.ipPrefixes {
|
||||||
var ip *netip.Addr
|
var ip *netip.Addr
|
||||||
|
@ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
|
||||||
// but this was quick to get running and it should be enough
|
// but this was quick to get running and it should be enough
|
||||||
// to begin experimenting with a dual stack tailnet.
|
// to begin experimenting with a dual stack tailnet.
|
||||||
var addressesSlices []string
|
var addressesSlices []string
|
||||||
hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||||
|
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
for _, slice := range addressesSlices {
|
for _, slice := range addressesSlices {
|
||||||
var machineAddresses MachineAddresses
|
var machineAddresses types.MachineAddresses
|
||||||
err := machineAddresses.Scan(slice)
|
err := machineAddresses.Scan(slice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &netipx.IPSet{}, fmt.Errorf(
|
return &netipx.IPSet{}, fmt.Errorf(
|
|
@ -1,14 +1,16 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
ips, err := app.db.getAvailableIPs()
|
ips, err := db.getAvailableIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
ips, err := app.db.getAvailableIPs()
|
ips, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
user, err := app.db.CreateUser("test-ip")
|
user, err := db.CreateUser("test-ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "testmachine")
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testmachine",
|
Hostname: "testmachine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
usedIps, err := app.db.getUsedIPs()
|
usedIps, err := db.getUsedIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||||
|
|
||||||
machine1, err := app.db.GetMachineByID(0)
|
machine1, err := db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test-ip-multi")
|
user, err := db.CreateUser("test-ip-multi")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
for index := 1; index <= 350; index++ {
|
for index := 1; index <= 350; index++ {
|
||||||
app.db.ipAllocationMutex.Lock()
|
db.ipAllocationMutex.Lock()
|
||||||
|
|
||||||
ips, err := app.db.getAvailableIPs()
|
ips, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "testmachine")
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: uint64(index),
|
ID: uint64(index),
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testmachine",
|
Hostname: "testmachine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
app.db.ipAllocationMutex.Unlock()
|
db.ipAllocationMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
usedIps, err := app.db.getUsedIPs()
|
usedIps, err := db.getUsedIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||||
|
@ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||||
|
|
||||||
// Check that we can read back the IPs
|
// Check that we can read back the IPs
|
||||||
machine1, err := app.db.GetMachineByID(1)
|
machine1, err := db.GetMachineByID(1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
|
@ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
netip.MustParseAddr("10.27.0.1"),
|
netip.MustParseAddr("10.27.0.1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
machine50, err := app.db.GetMachineByID(50)
|
machine50, err := db.GetMachineByID(50)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
|
@ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
)
|
)
|
||||||
|
|
||||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||||
nextIP, err := app.db.getAvailableIPs()
|
nextIP, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(nextIP), check.Equals, 1)
|
c.Assert(len(nextIP), check.Equals, 1)
|
||||||
|
@ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
|
|
||||||
// If we call get Available again, we should receive
|
// If we call get Available again, we should receive
|
||||||
// the same IP, as it has not been reserved.
|
// the same IP, as it has not been reserved.
|
||||||
nextIP2, err := app.db.getAvailableIPs()
|
nextIP2, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(nextIP2), check.Equals, 1)
|
c.Assert(len(nextIP2), check.Equals, 1)
|
||||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
ips, err := app.db.getAvailableIPs()
|
ips, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
expected := netip.MustParseAddr("10.27.0.1")
|
expected := netip.MustParseAddr("10.27.0.1")
|
||||||
|
@ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
c.Assert(len(ips), check.Equals, 1)
|
c.Assert(len(ips), check.Equals, 1)
|
||||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||||
|
|
||||||
user, err := app.db.CreateUser("test-ip")
|
user, err := db.CreateUser("test-ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "testmachine")
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testmachine",
|
Hostname: "testmachine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
ips2, err := app.db.getAvailableIPs()
|
ips2, err := db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(ips2), check.Equals, 1)
|
c.Assert(len(ips2), check.Equals, 1)
|
||||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -6,10 +6,9 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -19,22 +18,10 @@ const (
|
||||||
|
|
||||||
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
||||||
|
|
||||||
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
|
||||||
// headscale.
|
|
||||||
type APIKey struct {
|
|
||||||
ID uint64 `gorm:"primary_key"`
|
|
||||||
Prefix string `gorm:"uniqueIndex"`
|
|
||||||
Hash []byte
|
|
||||||
|
|
||||||
CreatedAt *time.Time
|
|
||||||
Expiration *time.Time
|
|
||||||
LastSeen *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
||||||
func (hsdb *HSDatabase) CreateAPIKey(
|
func (hsdb *HSDatabase) CreateAPIKey(
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
) (string, *APIKey, error) {
|
) (string, *types.APIKey, error) {
|
||||||
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
|
@ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := APIKey{
|
key := types.APIKey{
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
Hash: hash,
|
Hash: hash,
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
|
@ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||||
func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
|
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||||
keys := []APIKey{}
|
keys := []types.APIKey{}
|
||||||
if err := hsdb.db.Find(&keys).Error; err != nil {
|
if err := hsdb.db.Find(&keys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAPIKey returns a ApiKey for a given key.
|
// GetAPIKey returns a ApiKey for a given key.
|
||||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
|
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||||
key := APIKey{}
|
key := types.APIKey{}
|
||||||
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
@ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
||||||
key := APIKey{}
|
key := types.APIKey{}
|
||||||
if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil {
|
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
||||||
|
|
||||||
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
||||||
// does not exist.
|
// does not exist.
|
||||||
func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
|
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||||
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
@ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpireAPIKey marks a ApiKey as expired.
|
// ExpireAPIKey marks a ApiKey as expired.
|
||||||
func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error {
|
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||||
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *APIKey) toProto() *v1.ApiKey {
|
|
||||||
protoKey := v1.ApiKey{
|
|
||||||
Id: key.ID,
|
|
||||||
Prefix: key.Prefix,
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Expiration != nil {
|
|
||||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.CreatedAt != nil {
|
|
||||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.LastSeen != nil {
|
|
||||||
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &protoKey
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
@ -7,7 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreateAPIKey(c *check.C) {
|
func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil)
|
apiKeyStr, apiKey, err := db.CreateAPIKey(nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
|
@ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||||
c.Assert(apiKey.Hash, check.NotNil)
|
c.Assert(apiKey.Hash, check.NotNil)
|
||||||
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
||||||
|
|
||||||
_, err = app.db.ListAPIKeys()
|
_, err = db.ListAPIKeys()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
keys, err := app.db.ListAPIKeys()
|
keys, err := db.ListAPIKeys()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||||
key, err := app.db.GetAPIKey("does-not-exist")
|
key, err := db.GetAPIKey("does-not-exist")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
||||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, true)
|
c.Assert(valid, check.Equals, true)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||||
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2)
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, false)
|
c.Assert(valid, check.Equals, false)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now)
|
apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
validNow, err := app.db.ValidateAPIKey(apiKeyStrNow)
|
validNow, err := db.ValidateAPIKey(apiKeyStrNow)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(validNow, check.Equals, false)
|
c.Assert(validNow, check.Equals, false)
|
||||||
|
|
||||||
validSilly, err := app.db.ValidateAPIKey("nota.validkey")
|
validSilly, err := db.ValidateAPIKey("nota.validkey")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(validSilly, check.Equals, false)
|
c.Assert(validSilly, check.Equals, false)
|
||||||
|
|
||||||
validWithErr, err := app.db.ValidateAPIKey("produceerrorkey")
|
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(validWithErr, check.Equals, false)
|
c.Assert(validWithErr, check.Equals, false)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, true)
|
c.Assert(valid, check.Equals, true)
|
||||||
|
|
||||||
err = app.db.ExpireAPIKey(apiKey)
|
err = db.ExpireAPIKey(apiKey)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey.Expiration, check.NotNil)
|
c.Assert(apiKey.Expiration, check.NotNil)
|
||||||
|
|
||||||
notValid, err := app.db.ValidateAPIKey(apiKeyStr)
|
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(notValid, check.Equals, false)
|
c.Assert(notValid, check.Equals, false)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
|
@ -1,9 +1,7 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -11,11 +9,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -26,7 +25,6 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errValueNotFound = errors.New("not found")
|
errValueNotFound = errors.New("not found")
|
||||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
|
||||||
errDatabaseNotSupported = errors.New("database type not supported")
|
errDatabaseNotSupported = errors.New("database type not supported")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,6 +38,7 @@ type KV struct {
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
notifyStateChan chan<- struct{}
|
notifyStateChan chan<- struct{}
|
||||||
|
notifyPolicyChan chan<- struct{}
|
||||||
|
|
||||||
ipAllocationMutex sync.Mutex
|
ipAllocationMutex sync.Mutex
|
||||||
|
|
||||||
|
@ -54,6 +53,7 @@ func NewHeadscaleDatabase(
|
||||||
dbType, connectionAddr string,
|
dbType, connectionAddr string,
|
||||||
stripEmailDomain, debug bool,
|
stripEmailDomain, debug bool,
|
||||||
notifyStateChan chan<- struct{},
|
notifyStateChan chan<- struct{},
|
||||||
|
notifyPolicyChan chan<- struct{},
|
||||||
ipPrefixes []netip.Prefix,
|
ipPrefixes []netip.Prefix,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
) (*HSDatabase, error) {
|
) (*HSDatabase, error) {
|
||||||
|
@ -65,6 +65,7 @@ func NewHeadscaleDatabase(
|
||||||
db := HSDatabase{
|
db := HSDatabase{
|
||||||
db: dbConn,
|
db: dbConn,
|
||||||
notifyStateChan: notifyStateChan,
|
notifyStateChan: notifyStateChan,
|
||||||
|
notifyPolicyChan: notifyPolicyChan,
|
||||||
|
|
||||||
ipPrefixes: ipPrefixes,
|
ipPrefixes: ipPrefixes,
|
||||||
baseDomain: baseDomain,
|
baseDomain: baseDomain,
|
||||||
|
@ -79,30 +80,30 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(User{})
|
err = dbConn.AutoMigrate(types.User{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id")
|
||||||
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
|
||||||
|
|
||||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses")
|
||||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname")
|
||||||
|
|
||||||
// GivenName is used as the primary source of DNS names, make sure
|
// GivenName is used as the primary source of DNS names, make sure
|
||||||
// the field is populated and normalized if it was not when the
|
// the field is populated and normalized if it was not when the
|
||||||
// machine was registered.
|
// machine was registered.
|
||||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name")
|
||||||
|
|
||||||
// If the Machine table has a column for registered,
|
// If the Machine table has a column for registered,
|
||||||
// find all occourences of "false" and drop them. Then
|
// find all occourences of "false" and drop them. Then
|
||||||
// remove the column.
|
// remove the column.
|
||||||
if dbConn.Migrator().HasColumn(&Machine{}, "registered") {
|
if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") {
|
||||||
log.Info().
|
log.Info().
|
||||||
Msg(`Database has legacy "registered" column in machine, removing...`)
|
Msg(`Database has legacy "registered" column in machine, removing...`)
|
||||||
|
|
||||||
machines := Machines{}
|
machines := types.Machines{}
|
||||||
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
}
|
}
|
||||||
|
@ -112,7 +113,7 @@ func NewHeadscaleDatabase(
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Str("machine_key", machine.MachineKey).
|
Str("machine_key", machine.MachineKey).
|
||||||
Msg("Deleting unregistered machine")
|
Msg("Deleting unregistered machine")
|
||||||
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil {
|
if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
|
@ -121,23 +122,23 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dbConn.Migrator().DropColumn(&Machine{}, "registered")
|
err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error dropping registered column")
|
log.Error().Err(err).Msg("Error dropping registered column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(&Route{})
|
err = dbConn.AutoMigrate(&types.Route{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") {
|
||||||
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
||||||
|
|
||||||
type MachineAux struct {
|
type MachineAux struct {
|
||||||
ID uint64
|
ID uint64
|
||||||
EnabledRoutes IPPrefixes
|
EnabledRoutes types.IPPrefixes
|
||||||
}
|
}
|
||||||
|
|
||||||
machinesAux := []MachineAux{}
|
machinesAux := []MachineAux{}
|
||||||
|
@ -157,8 +158,8 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.Preload("Machine").
|
err = dbConn.Preload("Machine").
|
||||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)).
|
||||||
First(&Route{}).
|
First(&types.Route{}).
|
||||||
Error
|
Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.Info().
|
log.Info().
|
||||||
|
@ -168,11 +169,11 @@ func NewHeadscaleDatabase(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
route := Route{
|
route := types.Route{
|
||||||
MachineID: machine.ID,
|
MachineID: machine.ID,
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Prefix: IPPrefix(prefix),
|
Prefix: types.IPPrefix(prefix),
|
||||||
}
|
}
|
||||||
if err := dbConn.Create(&route).Error; err != nil {
|
if err := dbConn.Create(&route).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error creating route")
|
log.Error().Err(err).Msg("Error creating route")
|
||||||
|
@ -185,26 +186,26 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(&Machine{})
|
err = dbConn.AutoMigrate(&types.Machine{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") {
|
if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") {
|
||||||
machines := Machines{}
|
machines := types.Machines{}
|
||||||
if err := dbConn.Find(&machines).Error; err != nil {
|
if err := dbConn.Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
}
|
}
|
||||||
|
|
||||||
for item, machine := range machines {
|
for item, machine := range machines {
|
||||||
if machine.GivenName == "" {
|
if machine.GivenName == "" {
|
||||||
normalizedHostname, err := NormalizeToFQDNRules(
|
normalizedHostname, err := util.NormalizeToFQDNRules(
|
||||||
machine.Hostname,
|
machine.Hostname,
|
||||||
stripEmailDomain,
|
stripEmailDomain,
|
||||||
)
|
)
|
||||||
|
@ -233,19 +234,19 @@ func NewHeadscaleDatabase(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(&PreAuthKey{})
|
err = dbConn.AutoMigrate(&types.PreAuthKey{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{})
|
err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = dbConn.Migrator().DropTable("shared_machines")
|
_ = dbConn.Migrator().DropTable("shared_machines")
|
||||||
|
|
||||||
err = dbConn.AutoMigrate(&APIKey{})
|
err = dbConn.AutoMigrate(&types.APIKey{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
sqlDB, err := hsdb.db.DB()
|
sqlDB, err := hsdb.db.DB()
|
||||||
|
@ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
||||||
return sqlDB.PingContext(ctx)
|
return sqlDB.PingContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a "wrapper" type around tailscales
|
func (hsdb *HSDatabase) Close() error {
|
||||||
// Hostinfo to allow us to add database "serialization"
|
db, err := hsdb.db.DB()
|
||||||
// methods. This allows us to use a typed values throughout
|
|
||||||
// the code and not have to marshal/unmarshal and error
|
|
||||||
// check all over the code.
|
|
||||||
type HostInfo tailcfg.Hostinfo
|
|
||||||
|
|
||||||
func (hi *HostInfo) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(value, hi)
|
|
||||||
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(value), hi)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (hi HostInfo) Value() (driver.Value, error) {
|
|
||||||
bytes, err := json.Marshal(hi)
|
|
||||||
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPPrefix netip.Prefix
|
|
||||||
|
|
||||||
func (i *IPPrefix) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case string:
|
|
||||||
prefix, err := netip.ParsePrefix(value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*i = IPPrefix(prefix)
|
|
||||||
|
|
||||||
return nil
|
return db.Close()
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i IPPrefix) Value() (driver.Value, error) {
|
|
||||||
prefixStr := netip.Prefix(i).String()
|
|
||||||
|
|
||||||
return prefixStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPPrefixes []netip.Prefix
|
|
||||||
|
|
||||||
func (i *IPPrefixes) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(value, i)
|
|
||||||
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(value), i)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i IPPrefixes) Value() (driver.Value, error) {
|
|
||||||
bytes, err := json.Marshal(i)
|
|
||||||
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type StringList []string
|
|
||||||
|
|
||||||
func (i *StringList) Scan(destination interface{}) error {
|
|
||||||
switch value := destination.(type) {
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(value, i)
|
|
||||||
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(value), i)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value return json value, implement driver.Valuer interface.
|
|
||||||
func (i StringList) Value() (driver.Value, error) {
|
|
||||||
bytes, err := json.Marshal(i)
|
|
||||||
|
|
||||||
return string(bytes), err
|
|
||||||
}
|
}
|
File diff suppressed because it is too large
Load diff
797
hscontrol/db/machine_test.go
Normal file
797
hscontrol/db/machine_test.go
Normal file
|
@ -0,0 +1,797 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMachine(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := &types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machine)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
oldNodeKey := key.NewNode()
|
||||||
|
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(1),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
err = db.DeleteMachine(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(user.Name, "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine3",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(1),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
err = db.HardDeleteMachine(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestListPeers(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
for index := 0; index <= 10; index++ {
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: uint64(index),
|
||||||
|
MachineKey: "foo" + strconv.Itoa(index),
|
||||||
|
NodeKey: "bar" + strconv.Itoa(index),
|
||||||
|
DiscoKey: "faa" + strconv.Itoa(index),
|
||||||
|
Hostname: "testmachine" + strconv.Itoa(index),
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
}
|
||||||
|
|
||||||
|
machine0ByID, err := db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
peersOfMachine0, err := db.ListPeers(machine0ByID)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||||
|
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||||
|
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||||
|
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
|
type base struct {
|
||||||
|
user *types.User
|
||||||
|
key *types.PreAuthKey
|
||||||
|
}
|
||||||
|
|
||||||
|
stor := make([]base, 0)
|
||||||
|
|
||||||
|
for _, name := range []string{"test", "admin"} {
|
||||||
|
user, err := db.CreateUser(name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
stor = append(stor, base{user, pak})
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
for index := 0; index <= 10; index++ {
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: uint64(index),
|
||||||
|
MachineKey: "foo" + strconv.Itoa(index),
|
||||||
|
NodeKey: "bar" + strconv.Itoa(index),
|
||||||
|
DiscoKey: "faa" + strconv.Itoa(index),
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
|
||||||
|
},
|
||||||
|
Hostname: "testmachine" + strconv.Itoa(index),
|
||||||
|
UserID: stor[index%2].user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(stor[index%2].key.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
}
|
||||||
|
|
||||||
|
aclPolicy := &policy.ACLPolicy{
|
||||||
|
Groups: map[string][]string{
|
||||||
|
"group:test": {"admin"},
|
||||||
|
},
|
||||||
|
Hosts: map[string]netip.Prefix{},
|
||||||
|
TagOwners: map[string][]string{},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"admin"},
|
||||||
|
Destinations: []string{"*:*"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"test"},
|
||||||
|
Destinations: []string{"test:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tests: []policy.ACLTest{},
|
||||||
|
}
|
||||||
|
|
||||||
|
adminMachine, err := db.GetMachineByID(1)
|
||||||
|
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
testMachine, err := db.GetMachineByID(2)
|
||||||
|
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machines, err := db.ListMachines()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
|
||||||
|
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
|
||||||
|
|
||||||
|
c.Log(peersOfTestMachine)
|
||||||
|
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||||
|
c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1")
|
||||||
|
c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3")
|
||||||
|
c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5")
|
||||||
|
|
||||||
|
c.Log(peersOfAdminMachine)
|
||||||
|
c.Assert(len(peersOfAdminMachine), check.Equals, 9)
|
||||||
|
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||||
|
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||||
|
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := &types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
Expiry: &time.Time{},
|
||||||
|
}
|
||||||
|
db.db.Save(machine)
|
||||||
|
|
||||||
|
machineFromDB, err := db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(machineFromDB, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(machineFromDB.IsExpired(), check.Equals, false)
|
||||||
|
|
||||||
|
err = db.ExpireMachine(machineFromDB)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||||
|
input := types.MachineAddresses([]netip.Addr{
|
||||||
|
netip.MustParseAddr("192.0.2.1"),
|
||||||
|
netip.MustParseAddr("2001:db8::1"),
|
||||||
|
})
|
||||||
|
serialized, err := input.Value()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
if serial, ok := serialized.(string); ok {
|
||||||
|
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
|
||||||
|
}
|
||||||
|
|
||||||
|
var deserialized types.MachineAddresses
|
||||||
|
err = deserialized.Scan(serialized)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(len(deserialized), check.Equals, len(input))
|
||||||
|
for i := range deserialized {
|
||||||
|
c.Assert(deserialized[i], check.Equals, input[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||||
|
user1, err := db.CreateUser("user-1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("user-1", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := &types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "machine-key-1",
|
||||||
|
NodeKey: "node-key-1",
|
||||||
|
DiscoKey: "disco-key-1",
|
||||||
|
Hostname: "hostname-1",
|
||||||
|
GivenName: "hostname-1",
|
||||||
|
UserID: user1.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machine)
|
||||||
|
|
||||||
|
givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2")
|
||||||
|
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
|
||||||
|
c.Assert(err, check.IsNil, comment)
|
||||||
|
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||||
|
|
||||||
|
givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1")
|
||||||
|
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
|
||||||
|
c.Assert(err, check.IsNil, comment)
|
||||||
|
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||||
|
|
||||||
|
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||||
|
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
|
||||||
|
c.Assert(err, check.IsNil, comment)
|
||||||
|
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||||
|
|
||||||
|
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||||
|
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||||
|
c.Assert(err, check.IsNil, comment)
|
||||||
|
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestSetTags(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := &types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machine)
|
||||||
|
|
||||||
|
// assign simple tags
|
||||||
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
|
err = db.SetTags(machine, sTags)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
||||||
|
|
||||||
|
// assign duplicat tags, expect no errors but no doubles in DB
|
||||||
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
|
err = db.SetTags(machine, eTags)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine, err = db.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(
|
||||||
|
machine.ForcedTags,
|
||||||
|
check.DeepEquals,
|
||||||
|
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
suppliedName string
|
||||||
|
randomSuffix bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
db *HSDatabase
|
||||||
|
args args
|
||||||
|
want *regexp.Regexp
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple machine name generation",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "testmachine",
|
||||||
|
randomSuffix: false,
|
||||||
|
},
|
||||||
|
want: regexp.MustCompile("^testmachine$"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with 53 chars",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
||||||
|
randomSuffix: false,
|
||||||
|
},
|
||||||
|
want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with 63 chars",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
|
randomSuffix: false,
|
||||||
|
},
|
||||||
|
want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with 64 chars",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
||||||
|
randomSuffix: false,
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with 73 chars",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
||||||
|
randomSuffix: false,
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with random suffix",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "test",
|
||||||
|
randomSuffix: true,
|
||||||
|
},
|
||||||
|
want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "machine name with 63 chars with random suffix",
|
||||||
|
db: &HSDatabase{
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
|
randomSuffix: true,
|
||||||
|
},
|
||||||
|
want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf(
|
||||||
|
"Headscale.GenerateGivenName() error = %v, wantErr %v",
|
||||||
|
err,
|
||||||
|
tt.wantErr,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.want != nil && !tt.want.MatchString(got) {
|
||||||
|
t.Errorf(
|
||||||
|
"Headscale.GenerateGivenName() = %v, does not match %v",
|
||||||
|
tt.want,
|
||||||
|
got,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) > util.LabelHostnameLength {
|
||||||
|
t.Errorf(
|
||||||
|
"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d",
|
||||||
|
got,
|
||||||
|
util.LabelHostnameLength,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
|
acl := []byte(`
|
||||||
|
{
|
||||||
|
"tagOwners": {
|
||||||
|
"tag:exit": ["test"],
|
||||||
|
},
|
||||||
|
|
||||||
|
"groups": {
|
||||||
|
"group:test": ["test"]
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||||
|
],
|
||||||
|
|
||||||
|
"autoApprovers": {
|
||||||
|
"exitNode": ["tag:exit"],
|
||||||
|
"routes": {
|
||||||
|
"10.10.0.0/16": ["group:test"],
|
||||||
|
"10.11.0.0/16": ["test"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
|
||||||
|
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
|
||||||
|
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||||
|
defaultRouteV6 := netip.MustParsePrefix("::/0")
|
||||||
|
route1 := netip.MustParsePrefix("10.10.0.0/16")
|
||||||
|
// Check if a subprefix of an autoapproved route is approved
|
||||||
|
route2 := netip.MustParsePrefix("10.11.0.0/24")
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "test",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: types.HostInfo{
|
||||||
|
RequestTags: []string{"tag:exit"},
|
||||||
|
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
|
||||||
|
},
|
||||||
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
err = db.ProcessMachineRoutes(&machine)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machine0ByID, err := db.GetMachineByID(0)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
err = db.EnableAutoApprovedRoutes(pol, machine0ByID)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachine_canAccess(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
filter []tailcfg.FilterRule
|
||||||
|
machine2 *types.Machine
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
machine types.Machine
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-rules",
|
||||||
|
machine: types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
filter: []tailcfg.FilterRule{},
|
||||||
|
machine2: &types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard",
|
||||||
|
machine: types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
filter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"*"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{
|
||||||
|
IP: "*",
|
||||||
|
Ports: tailcfg.PortRange{
|
||||||
|
First: 0,
|
||||||
|
Last: 65535,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
machine2: &types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit-m1-to-m2",
|
||||||
|
machine: types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
filter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"10.0.0.1"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{
|
||||||
|
IP: "10.0.0.2",
|
||||||
|
Ports: tailcfg.PortRange{
|
||||||
|
First: 0,
|
||||||
|
Last: 65535,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
machine2: &types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit-m2-to-m1",
|
||||||
|
machine: types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
filter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"10.0.0.2"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{
|
||||||
|
IP: "10.0.0.1",
|
||||||
|
Ports: tailcfg.PortRange{
|
||||||
|
First: 0,
|
||||||
|
Last: 65535,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
machine2: &types.Machine{
|
||||||
|
IPAddresses: types.MachineAddresses{
|
||||||
|
netip.MustParseAddr("10.0.0.2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want {
|
||||||
|
t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,17 +1,14 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,28 +20,6 @@ var (
|
||||||
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
|
||||||
type PreAuthKey struct {
|
|
||||||
ID uint64 `gorm:"primary_key"`
|
|
||||||
Key string
|
|
||||||
UserID uint
|
|
||||||
User User
|
|
||||||
Reusable bool
|
|
||||||
Ephemeral bool `gorm:"default:false"`
|
|
||||||
Used bool `gorm:"default:false"`
|
|
||||||
ACLTags []PreAuthKeyACLTag
|
|
||||||
|
|
||||||
CreatedAt *time.Time
|
|
||||||
Expiration *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
|
|
||||||
type PreAuthKeyACLTag struct {
|
|
||||||
ID uint64 `gorm:"primary_key"`
|
|
||||||
PreAuthKeyID uint64
|
|
||||||
Tag string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
userName string,
|
userName string,
|
||||||
|
@ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
user, err := hsdb.GetUser(userName)
|
user, err := hsdb.GetUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := PreAuthKey{
|
key := types.PreAuthKey{
|
||||||
Key: kstr,
|
Key: kstr,
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
User: *user,
|
User: *user,
|
||||||
|
@ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
|
|
||||||
for _, tag := range aclTags {
|
for _, tag := range aclTags {
|
||||||
if !seenTags[tag] {
|
if !seenTags[tag] {
|
||||||
if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to ceate key tag in the database: %w",
|
"failed to ceate key tag in the database: %w",
|
||||||
err,
|
err,
|
||||||
|
@ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||||
user, err := hsdb.GetUser(userName)
|
user, err := hsdb.GetUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := []PreAuthKey{}
|
keys := []types.PreAuthKey{}
|
||||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) {
|
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
|
||||||
pak, err := hsdb.checkKeyValidity(key)
|
pak, err := hsdb.ValidatePreAuthKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err
|
||||||
|
|
||||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||||
// does not exist.
|
// does not exist.
|
||||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
|
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
|
||||||
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||||
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil {
|
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
|
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||||
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsePreAuthKey marks a PreAuthKey as used.
|
// UsePreAuthKey marks a PreAuthKey as used.
|
||||||
func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
|
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
|
||||||
k.Used = true
|
k.Used = true
|
||||||
if err := hsdb.db.Save(k).Error; err != nil {
|
if err := hsdb.db.Save(k).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||||
|
@ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||||
// If returns no error and a PreAuthKey, it can be used.
|
// If returns no error and a PreAuthKey, it can be used.
|
||||||
func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
|
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||||
pak := PreAuthKey{}
|
pak := types.PreAuthKey{}
|
||||||
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
|
@ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
return &pak, nil
|
return &pak, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := types.Machines{}
|
||||||
if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||||
|
|
||||||
return hex.EncodeToString(bytes), nil
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
|
||||||
protoKey := v1.PreAuthKey{
|
|
||||||
User: key.User.Name,
|
|
||||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
|
||||||
Key: key.Key,
|
|
||||||
Ephemeral: key.Ephemeral,
|
|
||||||
Reusable: key.Reusable,
|
|
||||||
Used: key.Used,
|
|
||||||
AclTags: make([]string, len(key.ACLTags)),
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Expiration != nil {
|
|
||||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.CreatedAt != nil {
|
|
||||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx := range key.ACLTags {
|
|
||||||
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
|
|
||||||
}
|
|
||||||
|
|
||||||
return &protoKey
|
|
||||||
}
|
|
|
@ -1,20 +1,22 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||||
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
c.Assert(key.User.Name, check.Equals, user.Name)
|
||||||
|
|
||||||
_, err = app.db.ListPreAuthKeys("bogus")
|
_, err = db.ListPreAuthKeys("bogus")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := app.db.ListPreAuthKeys(user.Name)
|
keys, err := db.ListPreAuthKeys(user.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
|
@ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test2")
|
user, err := db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||||
key, err := app.db.checkKeyValidity("potatoKey")
|
key, err := db.ValidatePreAuthKey("potatoKey")
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test4")
|
user, err := db.CreateUser("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test5")
|
user, err := db.CreateUser("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test7")
|
user, err := db.CreateUser("test7")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now().Add(-time.Second * 30)
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
_, err = app.db.checkKeyValidity(pak.Key)
|
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||||
// Ephemeral keys are by definition reusable
|
// Ephemeral keys are by definition reusable
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test7", "testest")
|
_, err = db.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
app.expireEphemeralNodesWorker()
|
db.ExpireEphemeralMachines(time.Second * 20)
|
||||||
|
|
||||||
// The machine record should have been deleted
|
// The machine record should have been deleted
|
||||||
_, err = app.db.GetMachine("test7", "testest")
|
_, err = db.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
err = app.db.ExpirePreAuthKey(pak)
|
err = db.ExpirePreAuthKey(pak)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.NotNil)
|
c.Assert(pak.Expiration, check.NotNil)
|
||||||
|
|
||||||
key, err := app.db.checkKeyValidity(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
app.db.db.Save(&pak)
|
db.db.Save(&pak)
|
||||||
|
|
||||||
_, err = app.db.checkKeyValidity(pak.Key)
|
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test8")
|
user, err := db.CreateUser("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
tags := []string{"tag:test1", "tag:test2"}
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
listedPaks, err := app.db.ListPreAuthKeys("test8")
|
listedPaks, err := db.ListPreAuthKeys("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags)
|
||||||
}
|
}
|
|
@ -1,55 +1,19 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||||
ErrRouteIsNotAvailable = errors.New("route is not available")
|
|
||||||
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
|
||||||
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Route struct {
|
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||||
gorm.Model
|
var routes types.Routes
|
||||||
|
|
||||||
MachineID uint64
|
|
||||||
Machine Machine
|
|
||||||
Prefix IPPrefix
|
|
||||||
|
|
||||||
Advertised bool
|
|
||||||
Enabled bool
|
|
||||||
IsPrimary bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Routes []Route
|
|
||||||
|
|
||||||
func (r *Route) String() string {
|
|
||||||
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Route) isExitRoute() bool {
|
|
||||||
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rs Routes) toPrefixes() []netip.Prefix {
|
|
||||||
prefixes := make([]netip.Prefix, len(rs))
|
|
||||||
for i, r := range rs {
|
|
||||||
prefixes[i] = netip.Prefix(r.Prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
return prefixes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
|
|
||||||
var routes []Route
|
|
||||||
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
|
func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
|
||||||
var routes []Route
|
var routes types.Routes
|
||||||
|
err := hsdb.db.
|
||||||
|
Preload("Machine").
|
||||||
|
Where("machine_id = ? AND advertised = true", machine.ID).
|
||||||
|
Find(&routes).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
|
||||||
|
var routes types.Routes
|
||||||
err := hsdb.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ?", m.ID).
|
Where("machine_id = ?", m.ID).
|
||||||
|
@ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) {
|
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||||
var route Route
|
var route types.Route
|
||||||
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
if route.isExitRoute() {
|
if route.IsExitRoute() {
|
||||||
return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
|
return hsdb.enableRoutes(
|
||||||
|
&route.Machine,
|
||||||
|
types.ExitRouteV4.String(),
|
||||||
|
types.ExitRouteV6.String(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
||||||
|
@ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
if !route.isExitRoute() {
|
if !route.IsExitRoute() {
|
||||||
route.Enabled = false
|
route.Enabled = false
|
||||||
route.IsPrimary = false
|
route.IsPrimary = false
|
||||||
err = hsdb.db.Save(route).Error
|
err = hsdb.db.Save(route).Error
|
||||||
|
@ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.handlePrimarySubnetFailover()
|
return hsdb.HandlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||||
|
@ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range routes {
|
for i := range routes {
|
||||||
if routes[i].isExitRoute() {
|
if routes[i].IsExitRoute() {
|
||||||
routes[i].Enabled = false
|
routes[i].Enabled = false
|
||||||
routes[i].IsPrimary = false
|
routes[i].IsPrimary = false
|
||||||
err = hsdb.db.Save(&routes[i]).Error
|
err = hsdb.db.Save(&routes[i]).Error
|
||||||
|
@ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.handlePrimarySubnetFailover()
|
return hsdb.HandlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||||
|
@ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
if !route.isExitRoute() {
|
if !route.IsExitRoute() {
|
||||||
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.handlePrimarySubnetFailover()
|
return hsdb.HandlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||||
|
@ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routesToDelete := []Route{}
|
routesToDelete := types.Routes{}
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if r.isExitRoute() {
|
if r.IsExitRoute() {
|
||||||
routesToDelete = append(routesToDelete, r)
|
routesToDelete = append(routesToDelete, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.handlePrimarySubnetFailover()
|
return hsdb.HandlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
|
func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
|
||||||
routes, err := hsdb.GetMachineRoutes(m)
|
routes, err := hsdb.GetMachineRoutes(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsdb.handlePrimarySubnetFailover()
|
return hsdb.HandlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isUniquePrefix returns if there is another machine providing the same route already.
|
// isUniquePrefix returns if there is another machine providing the same route already.
|
||||||
func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
|
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
||||||
var count int64
|
var count int64
|
||||||
hsdb.db.
|
hsdb.db.
|
||||||
Model(&Route{}).
|
Model(&types.Route{}).
|
||||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||||
route.Prefix,
|
route.Prefix,
|
||||||
route.MachineID,
|
route.MachineID,
|
||||||
|
@ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
|
||||||
return count == 0
|
return count == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
|
||||||
var route Route
|
var route types.Route
|
||||||
err := hsdb.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
|
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||||
First(&route).Error
|
First(&route).Error
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
||||||
|
|
||||||
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||||
func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
|
||||||
var routes []Route
|
var routes types.Routes
|
||||||
err := hsdb.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
||||||
|
@ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
|
||||||
currentRoutes := []Route{}
|
currentRoutes := types.Routes{}
|
||||||
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||||
|
|
||||||
for prefix, exists := range advertisedRoutes {
|
for prefix, exists := range advertisedRoutes {
|
||||||
if !exists {
|
if !exists {
|
||||||
route := Route{
|
route := types.Route{
|
||||||
MachineID: machine.ID,
|
MachineID: machine.ID,
|
||||||
Prefix: IPPrefix(prefix),
|
Prefix: types.IPPrefix(prefix),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
}
|
}
|
||||||
|
@ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||||
// first, get all the enabled routes
|
// first, get all the enabled routes
|
||||||
var routes []Route
|
var routes types.Routes
|
||||||
err := hsdb.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("advertised = ? AND enabled = ?", true, true).
|
Where("advertised = ? AND enabled = ?", true, true).
|
||||||
|
@ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
|
|
||||||
routesChanged := false
|
routesChanged := false
|
||||||
for pos, route := range routes {
|
for pos, route := range routes {
|
||||||
if route.isExitRoute() {
|
if route.IsExitRoute() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if route.IsPrimary {
|
if route.IsPrimary {
|
||||||
if route.Machine.isOnline() {
|
if route.Machine.IsOnline() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
Msgf("machine offline, finding a new primary subnet")
|
Msgf("machine offline, finding a new primary subnet")
|
||||||
|
|
||||||
// find a new primary route
|
// find a new primary route
|
||||||
var newPrimaryRoutes []Route
|
var newPrimaryRoutes types.Routes
|
||||||
err := hsdb.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||||
|
@ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPrimaryRoute *Route
|
var newPrimaryRoute *types.Route
|
||||||
for pos, r := range newPrimaryRoutes {
|
for pos, r := range newPrimaryRoutes {
|
||||||
if r.Machine.isOnline() {
|
if r.Machine.IsOnline() {
|
||||||
newPrimaryRoute = &newPrimaryRoutes[pos]
|
newPrimaryRoute = &newPrimaryRoutes[pos]
|
||||||
|
|
||||||
break
|
break
|
||||||
|
@ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs Routes) toProto() []*v1.Route {
|
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
||||||
protoRoutes := []*v1.Route{}
|
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||||
|
aclPolicy *policy.ACLPolicy,
|
||||||
for _, route := range rs {
|
machine *types.Machine,
|
||||||
protoRoute := v1.Route{
|
) error {
|
||||||
Id: uint64(route.ID),
|
if len(machine.IPAddresses) == 0 {
|
||||||
Machine: route.Machine.toProto(),
|
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||||
Prefix: netip.Prefix(route.Prefix).String(),
|
|
||||||
Advertised: route.Advertised,
|
|
||||||
Enabled: route.Enabled,
|
|
||||||
IsPrimary: route.IsPrimary,
|
|
||||||
CreatedAt: timestamppb.New(route.CreatedAt),
|
|
||||||
UpdatedAt: timestamppb.New(route.UpdatedAt),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if route.DeletedAt.Valid {
|
routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
|
||||||
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Str("machine", machine.Hostname).
|
||||||
|
Msg("Could not get advertised routes for machine")
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
protoRoutes = append(protoRoutes, &protoRoute)
|
approvedRoutes := types.Routes{}
|
||||||
|
|
||||||
|
for _, advertisedRoute := range routes {
|
||||||
|
if advertisedRoute.Enabled {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return protoRoutes
|
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||||
|
netip.Prefix(advertisedRoute.Prefix),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).
|
||||||
|
Str("advertisedRoute", advertisedRoute.String()).
|
||||||
|
Uint64("machineId", machine.ID).
|
||||||
|
Msg("Failed to resolve autoApprovers for advertised route")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, approvedAlias := range routeApprovers {
|
||||||
|
if approvedAlias == machine.User.Name {
|
||||||
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
|
} else {
|
||||||
|
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||||
|
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).
|
||||||
|
Str("alias", approvedAlias).
|
||||||
|
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first
|
||||||
|
if approvedIps.Contains(machine.IPAddresses[0]) {
|
||||||
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, approvedRoute := range approvedRoutes {
|
||||||
|
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).
|
||||||
|
Str("approvedRoute", approvedRoute.String()).
|
||||||
|
Uint64("machineId", machine.ID).
|
||||||
|
Msg("Failed to enable approved route")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
|
@ -1,9 +1,11 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -11,13 +13,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_get_route_machine")
|
_, err = db.GetMachine("test", "test_get_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
RoutableIPs: []netip.Prefix{route},
|
RoutableIPs: []netip.Prefix{route},
|
||||||
}
|
}
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_get_route_machine",
|
Hostname: "test_get_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine)
|
err = db.ProcessMachineRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
advertisedRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
RoutableIPs: []netip.Prefix{route, route2},
|
RoutableIPs: []netip.Prefix{route, route2},
|
||||||
}
|
}
|
||||||
|
|
||||||
machine := Machine{
|
machine := types.Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: types.HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine)
|
db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine)
|
err = db.ProcessMachineRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
availableRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
availableRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(availableRoutes), check.Equals, 2)
|
c.Assert(len(availableRoutes), check.Equals, 2)
|
||||||
|
|
||||||
noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
noEnabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
enabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||||
|
|
||||||
// Adding it twice will just let it pass through
|
// Adding it twice will just let it pass through
|
||||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine)
|
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine, "150.0.10.0/25")
|
err = db.enableRoutes(&machine, "150.0.10.0/25")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine)
|
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
hostInfo1 := tailcfg.Hostinfo{
|
hostInfo1 := tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{route, route2},
|
RoutableIPs: []netip.Prefix{route, route2},
|
||||||
}
|
}
|
||||||
machine1 := Machine{
|
machine1 := types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: types.HostInfo(hostInfo1),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine1)
|
db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine1)
|
err = db.ProcessMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, route.String())
|
err = db.enableRoutes(&machine1, route.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, route2.String())
|
err = db.enableRoutes(&machine1, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
hostInfo2 := tailcfg.Hostinfo{
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{route2},
|
RoutableIPs: []netip.Prefix{route2},
|
||||||
}
|
}
|
||||||
machine2 := Machine{
|
machine2 := types.Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo2),
|
HostInfo: types.HostInfo(hostInfo2),
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine2)
|
db.db.Save(&machine2)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine2)
|
err = db.ProcessMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine2, route2.String())
|
err = db.enableRoutes(&machine2, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||||
|
|
||||||
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine1 := Machine{
|
machine1 := types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: types.HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine1)
|
db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine1)
|
err = db.ProcessMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
err = db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, prefix2.String())
|
err = db.enableRoutes(&machine1, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.handlePrimarySubnetFailover()
|
err = db.HandlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
route, err := app.db.getPrimaryRoute(prefix)
|
route, err := db.getPrimaryRoute(prefix)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
||||||
|
|
||||||
hostInfo2 := tailcfg.Hostinfo{
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{prefix2},
|
RoutableIPs: []netip.Prefix{prefix2},
|
||||||
}
|
}
|
||||||
machine2 := Machine{
|
machine2 := types.Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo2),
|
HostInfo: types.HostInfo(hostInfo2),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine2)
|
db.db.Save(&machine2)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine2)
|
err = db.ProcessMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine2, prefix2.String())
|
err = db.enableRoutes(&machine2, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.handlePrimarySubnetFailover()
|
err = db.HandlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||||
|
|
||||||
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
// lets make machine1 lastseen 10 mins ago
|
// lets make machine1 lastseen 10 mins ago
|
||||||
before := now.Add(-10 * time.Minute)
|
before := now.Add(-10 * time.Minute)
|
||||||
machine1.LastSeen = &before
|
machine1.LastSeen = &before
|
||||||
err = app.db.db.Save(&machine1).Error
|
err = db.db.Save(&machine1).Error
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.handlePrimarySubnetFailover()
|
err = db.HandlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 1)
|
c.Assert(len(routes), check.Equals, 1)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 1)
|
c.Assert(len(routes), check.Equals, 1)
|
||||||
|
|
||||||
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{
|
machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||||
})
|
})
|
||||||
err = app.db.db.Save(&machine2).Error
|
err = db.db.Save(&machine2).Error
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine2)
|
err = db.ProcessMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine2, prefix.String())
|
err = db.enableRoutes(&machine2, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.handlePrimarySubnetFailover()
|
err = db.HandlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(6))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
|
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
|
||||||
// including both the primary routes the node is responsible for, and the
|
// including both the primary routes the node is responsible for, and the
|
||||||
// exit node routes if enabled.
|
// exit node routes if enabled.
|
||||||
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine1 := Machine{
|
machine1 := types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
|
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: types.HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine1)
|
db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine1)
|
err = db.ProcessMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
err = db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// We do not enable this one on purpose to test that it is not enabled
|
// We do not enable this one on purpose to test that it is not enabled
|
||||||
// err = app.db.enableRoutes(&machine1, prefix2.String())
|
// err = db.enableRoutes(&machine1, prefix2.String())
|
||||||
// c.Assert(err, check.IsNil)
|
// c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err := app.db.GetMachineRoutes(&machine1)
|
routes, err := db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.isExitRoute() {
|
if route.IsExitRoute() {
|
||||||
err = app.db.EnableRoute(uint64(route.ID))
|
err = db.EnableRoute(uint64(route.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// We only enable one exit route, so we can test that both are enabled
|
// We only enable one exit route, so we can test that both are enabled
|
||||||
|
@ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.db.handlePrimarySubnetFailover()
|
err = db.HandlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
||||||
|
|
||||||
peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil)
|
peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
||||||
|
@ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
|
|
||||||
// Now we disable only one of the exit routes
|
// Now we disable only one of the exit routes
|
||||||
// and we see if both are disabled
|
// and we see if both are disabled
|
||||||
var exitRouteV4 Route
|
var exitRouteV4 types.Route
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
|
if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
|
||||||
exitRouteV4 = route
|
exitRouteV4 = route
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.db.DisableRoute(uint64(exitRouteV4.ID))
|
err = db.DisableRoute(uint64(exitRouteV4.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
|
|
||||||
// and now we delete only one of the exit routes
|
// and now we delete only one of the exit routes
|
||||||
// and we check if both are deleted
|
// and we check if both are deleted
|
||||||
routes, err = app.db.GetMachineRoutes(&machine1)
|
routes, err = db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 4)
|
c.Assert(len(routes), check.Equals, 4)
|
||||||
|
|
||||||
err = app.db.DeleteRoute(uint64(exitRouteV4.ID))
|
err = db.DeleteRoute(uint64(exitRouteV4.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.db.GetMachineRoutes(&machine1)
|
routes, err = db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
user, err := app.db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine1 := Machine{
|
machine1 := types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: types.HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.db.Save(&machine1)
|
db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.db.processMachineRoutes(&machine1)
|
err = db.ProcessMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
err = db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.enableRoutes(&machine1, prefix2.String())
|
err = db.enableRoutes(&machine1, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err := app.db.GetMachineRoutes(&machine1)
|
routes, err := db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.db.DeleteRoute(uint64(routes[0].ID))
|
err = db.DeleteRoute(uint64(routes[0].ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
|
|
||||||
|
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||||
}
|
}
|
74
hscontrol/db/suite_test.go
Normal file
74
hscontrol/db/suite_test.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
check.TestingT(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = check.Suite(&Suite{})
|
||||||
|
|
||||||
|
type Suite struct{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
tmpDir string
|
||||||
|
db *HSDatabase
|
||||||
|
|
||||||
|
// channelUpdates counts the number of times
|
||||||
|
// either of the channels was notified.
|
||||||
|
channelUpdates int32
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Suite) SetUpTest(c *check.C) {
|
||||||
|
atomic.StoreInt32(&channelUpdates, 0)
|
||||||
|
s.ResetDB(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TearDownTest(c *check.C) {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func notificationSink(c <-chan struct{}) {
|
||||||
|
for {
|
||||||
|
<-c
|
||||||
|
atomic.AddInt32(&channelUpdates, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) ResetDB(c *check.C) {
|
||||||
|
if len(tmpDir) != 0 {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sink := make(chan struct{})
|
||||||
|
|
||||||
|
go notificationSink(sink)
|
||||||
|
|
||||||
|
db, err = NewHeadscaleDatabase(
|
||||||
|
"sqlite3",
|
||||||
|
tmpDir+"/headscale_test.db",
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
sink,
|
||||||
|
sink,
|
||||||
|
[]netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,17 +1,12 @@
|
||||||
package hscontrol
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -20,33 +15,16 @@ var (
|
||||||
ErrUserExists = errors.New("user already exists")
|
ErrUserExists = errors.New("user already exists")
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||||
ErrInvalidUserName = errors.New("invalid user name")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// value related to RFC 1123 and 952.
|
|
||||||
labelHostnameLength = 63
|
|
||||||
)
|
|
||||||
|
|
||||||
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
|
||||||
|
|
||||||
// User is the way Headscale implements the concept of users in Tailscale
|
|
||||||
//
|
|
||||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
|
||||||
// that contain our machines.
|
|
||||||
type User struct {
|
|
||||||
gorm.Model
|
|
||||||
Name string `gorm:"unique"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new User. Returns error if could not be created
|
// CreateUser creates a new User. Returns error if could not be created
|
||||||
// or another user already exists.
|
// or another user already exists.
|
||||||
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) {
|
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||||
err := CheckForFQDNRules(name)
|
err := util.CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := User{}
|
user := types.User{}
|
||||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||||
return nil, ErrUserExists
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
|
@ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = CheckForFQDNRules(newName)
|
err = util.CheckForFQDNRules(newName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUser fetches a user by name.
|
// GetUser fetches a user by name.
|
||||||
func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||||
user := User{}
|
user := types.User{}
|
||||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
|
@ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||||
users := []User{}
|
users := []types.User{}
|
||||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListMachinesByUser gets all the nodes in a given user.
|
// ListMachinesByUser gets all the nodes in a given user.
|
||||||
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) {
|
||||||
err := CheckForFQDNRules(name)
|
err := util.CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := types.Machines{}
|
||||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMachineUser assigns a Machine to a user.
|
// SetMachineUser assigns a Machine to a user.
|
||||||
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error {
|
func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error {
|
||||||
err := CheckForFQDNRules(username)
|
err := util.CheckForFQDNRules(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *User) toTailscaleUser() *tailcfg.User {
|
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
|
||||||
user := tailcfg.User{
|
machine types.Machine,
|
||||||
ID: tailcfg.UserID(n.ID),
|
peers types.Machines,
|
||||||
LoginName: n.Name,
|
|
||||||
DisplayName: n.Name,
|
|
||||||
ProfilePicURL: "",
|
|
||||||
Domain: "headscale.net",
|
|
||||||
Logins: []tailcfg.LoginID{},
|
|
||||||
Created: time.Time{},
|
|
||||||
}
|
|
||||||
|
|
||||||
return &user
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *User) toTailscaleLogin() *tailcfg.Login {
|
|
||||||
login := tailcfg.Login{
|
|
||||||
ID: tailcfg.LoginID(n.ID),
|
|
||||||
LoginName: n.Name,
|
|
||||||
DisplayName: n.Name,
|
|
||||||
ProfilePicURL: "",
|
|
||||||
Domain: "headscale.net",
|
|
||||||
}
|
|
||||||
|
|
||||||
return &login
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
|
||||||
machine Machine,
|
|
||||||
peers Machines,
|
|
||||||
) []tailcfg.UserProfile {
|
) []tailcfg.UserProfile {
|
||||||
userMap := make(map[string]User)
|
userMap := make(map[string]types.User)
|
||||||
userMap[machine.User.Name] = machine.User
|
userMap[machine.User.Name] = machine.User
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||||
|
@ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
||||||
|
|
||||||
return profiles
|
return profiles
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *User) toProto() *v1.User {
|
|
||||||
return &v1.User{
|
|
||||||
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
|
||||||
Name: n.Name,
|
|
||||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
|
||||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
|
||||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
name = strings.ReplaceAll(name, "'", "")
|
|
||||||
atIdx := strings.Index(name, "@")
|
|
||||||
if stripEmailDomain && atIdx > 0 {
|
|
||||||
name = name[:atIdx]
|
|
||||||
} else {
|
|
||||||
name = strings.ReplaceAll(name, "@", ".")
|
|
||||||
}
|
|
||||||
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
|
||||||
|
|
||||||
for _, elt := range strings.Split(name, ".") {
|
|
||||||
if len(elt) > labelHostnameLength {
|
|
||||||
return "", fmt.Errorf(
|
|
||||||
"label %v is more than 63 chars: %w",
|
|
||||||
elt,
|
|
||||||
ErrInvalidUserName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckForFQDNRules(name string) error {
|
|
||||||
if len(name) > labelHostnameLength {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
|
|
||||||
name,
|
|
||||||
ErrInvalidUserName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if strings.ToLower(name) != name {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
|
|
||||||
name,
|
|
||||||
ErrInvalidUserName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if invalidCharsInUserRegex.MatchString(name) {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
|
|
||||||
name,
|
|
||||||
ErrInvalidUserName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
277
hscontrol/db/users_test.go
Normal file
277
hscontrol/db/users_test.go
Normal file
|
@ -0,0 +1,277 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(user.Name, check.Equals, "test")
|
||||||
|
|
||||||
|
users, err := db.ListUsers()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
|
err = db.DestroyUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetUser("test")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
|
err := db.DestroyUser("test")
|
||||||
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
|
user, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
err = db.DestroyUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
|
// destroying a user also deletes all associated preauthkeys
|
||||||
|
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
user, err = db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
|
||||||
|
err = db.DestroyUser("test")
|
||||||
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
|
userTest, err := db.CreateUser("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(userTest.Name, check.Equals, "test")
|
||||||
|
|
||||||
|
users, err := db.ListUsers()
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
|
err = db.RenameUser("test", "test-renamed")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetUser("test")
|
||||||
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
|
_, err = db.GetUser("test-renamed")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
err = db.RenameUser("test-does-not-exit", "test")
|
||||||
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
|
userTest2, err := db.CreateUser("test2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
|
err = db.RenameUser("test2", "test-renamed")
|
||||||
|
c.Assert(err, check.Equals, ErrUserExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
|
userShared1, err := db.CreateUser("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
userShared2, err := db.CreateUser("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
userShared3, err := db.CreateUser("shared3")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
preAuthKeyShared1, err := db.CreatePreAuthKey(
|
||||||
|
userShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
preAuthKeyShared2, err := db.CreatePreAuthKey(
|
||||||
|
userShared2.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
preAuthKeyShared3, err := db.CreatePreAuthKey(
|
||||||
|
userShared3.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
preAuthKey2Shared1, err := db.CreatePreAuthKey(
|
||||||
|
userShared1.Name,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machineInShared1 := &types.Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Hostname: "test_get_shared_nodes_1",
|
||||||
|
UserID: userShared1.ID,
|
||||||
|
User: *userShared1,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
|
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machineInShared1)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machineInShared2 := &types.Machine{
|
||||||
|
ID: 2,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Hostname: "test_get_shared_nodes_2",
|
||||||
|
UserID: userShared2.ID,
|
||||||
|
User: *userShared2,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
|
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machineInShared2)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machineInShared3 := &types.Machine{
|
||||||
|
ID: 3,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Hostname: "test_get_shared_nodes_3",
|
||||||
|
UserID: userShared3.ID,
|
||||||
|
User: *userShared3,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
|
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machineInShared3)
|
||||||
|
|
||||||
|
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machine2InShared1 := &types.Machine{
|
||||||
|
ID: 4,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Hostname: "test_get_shared_nodes_4",
|
||||||
|
UserID: userShared1.ID,
|
||||||
|
User: *userShared1,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
|
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(machine2InShared1)
|
||||||
|
|
||||||
|
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
userProfiles := db.GetMapResponseUserProfiles(
|
||||||
|
*machineInShared1,
|
||||||
|
peersOfMachine1InShared1,
|
||||||
|
)
|
||||||
|
|
||||||
|
c.Assert(len(userProfiles), check.Equals, 3)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, userProfiles := range userProfiles {
|
||||||
|
if userProfiles.DisplayName == userShared1.Name {
|
||||||
|
found = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Assert(found, check.Equals, true)
|
||||||
|
|
||||||
|
found = false
|
||||||
|
for _, userProfile := range userProfiles {
|
||||||
|
if userProfile.DisplayName == userShared2.Name {
|
||||||
|
found = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Assert(found, check.Equals, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
|
oldUser, err := db.CreateUser("old")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
newUser, err := db.CreateUser("new")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
machine := types.Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
UserID: oldUser.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.db.Save(&machine)
|
||||||
|
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
|
err = db.SetMachineUser(&machine, newUser.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||||
|
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
||||||
|
err = db.SetMachineUser(&machine, "non-existing-user")
|
||||||
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
|
err = db.SetMachineUser(&machine, newUser.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||||
|
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
mapset "github.com/deckarep/golang-set/v2"
|
mapset "github.com/deckarep/golang-set/v2"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
|
@ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||||
//
|
//
|
||||||
// This will produce a resolver like:
|
// This will produce a resolver like:
|
||||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
|
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||||
for _, resolver := range resolvers {
|
for _, resolver := range resolvers {
|
||||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||||
attrs := url.Values{
|
attrs := url.Values{
|
||||||
|
@ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
|
||||||
func getMapResponseDNSConfig(
|
func getMapResponseDNSConfig(
|
||||||
dnsConfigOrig *tailcfg.DNSConfig,
|
dnsConfigOrig *tailcfg.DNSConfig,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
machine Machine,
|
machine types.Machine,
|
||||||
peers Machines,
|
peers types.Machines,
|
||||||
) *tailcfg.DNSConfig {
|
) *tailcfg.DNSConfig {
|
||||||
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
|
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
|
||||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
||||||
|
@ -200,7 +201,7 @@ func getMapResponseDNSConfig(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
userSet := mapset.NewSet[User]()
|
userSet := mapset.NewSet[types.User]()
|
||||||
userSet.Add(machine.User)
|
userSet.Add(machine.User)
|
||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
userSet.Add(p.User)
|
userSet.Add(p.User)
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
|
@ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
machineInShared1 := &types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
@ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_1",
|
Hostname: "test_get_shared_nodes_1",
|
||||||
UserID: userShared1.ID,
|
UserID: userShared1.ID,
|
||||||
User: *userShared1,
|
User: *userShared1,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared1)
|
err = app.db.MachineSave(machineInShared1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
machineInShared2 := &types.Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_2",
|
Hostname: "test_get_shared_nodes_2",
|
||||||
UserID: userShared2.ID,
|
UserID: userShared2.ID,
|
||||||
User: *userShared2,
|
User: *userShared2,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared2)
|
err = app.db.MachineSave(machineInShared2)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
machineInShared3 := &types.Machine{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_3",
|
Hostname: "test_get_shared_nodes_3",
|
||||||
UserID: userShared3.ID,
|
UserID: userShared3.ID,
|
||||||
User: *userShared3,
|
User: *userShared3,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared3)
|
err = app.db.MachineSave(machineInShared3)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
machine2InShared1 := &types.Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_4",
|
Hostname: "test_get_shared_nodes_4",
|
||||||
UserID: userShared1.ID,
|
UserID: userShared1.ID,
|
||||||
User: *userShared1,
|
User: *userShared1,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machine2InShared1)
|
err = app.db.MachineSave(machine2InShared1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
dnsConfigOrig := tailcfg.DNSConfig{
|
||||||
|
@ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Proxied: true,
|
Proxied: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
|
@ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
machineInShared1 := &types.Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
@ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_1",
|
Hostname: "test_get_shared_nodes_1",
|
||||||
UserID: userShared1.ID,
|
UserID: userShared1.ID,
|
||||||
User: *userShared1,
|
User: *userShared1,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared1)
|
err = app.db.MachineSave(machineInShared1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
machineInShared2 := &types.Machine{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_2",
|
Hostname: "test_get_shared_nodes_2",
|
||||||
UserID: userShared2.ID,
|
UserID: userShared2.ID,
|
||||||
User: *userShared2,
|
User: *userShared2,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared2)
|
err = app.db.MachineSave(machineInShared2)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
machineInShared3 := &types.Machine{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_3",
|
Hostname: "test_get_shared_nodes_3",
|
||||||
UserID: userShared3.ID,
|
UserID: userShared3.ID,
|
||||||
User: *userShared3,
|
User: *userShared3,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machineInShared3)
|
err = app.db.MachineSave(machineInShared3)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
machine2InShared1 := &types.Machine{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
@ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Hostname: "test_get_shared_nodes_4",
|
Hostname: "test_get_shared_nodes_4",
|
||||||
UserID: userShared1.ID,
|
UserID: userShared1.ID,
|
||||||
User: *userShared1,
|
User: *userShared1,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.db.Save(machine2InShared1)
|
err = app.db.MachineSave(machine2InShared1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
dnsConfigOrig := tailcfg.DNSConfig{
|
||||||
|
@ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Proxied: false,
|
Proxied: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
@ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.GetUserResponse{User: user.toProto()}, nil
|
return &v1.GetUserResponse{User: user.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) CreateUser(
|
func (api headscaleV1APIServer) CreateUser(
|
||||||
|
@ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.CreateUserResponse{User: user.toProto()}, nil
|
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) RenameUser(
|
func (api headscaleV1APIServer) RenameUser(
|
||||||
|
@ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.RenameUserResponse{User: user.toProto()}, nil
|
return &v1.RenameUserResponse{User: user.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) DeleteUser(
|
func (api headscaleV1APIServer) DeleteUser(
|
||||||
|
@ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers(
|
||||||
|
|
||||||
response := make([]*v1.User, len(users))
|
response := make([]*v1.User, len(users))
|
||||||
for index, user := range users {
|
for index, user := range users {
|
||||||
response[index] = user.toProto()
|
response[index] = user.Proto()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Caller().Interface("users", response).Msg("")
|
log.Trace().Caller().Interface("users", response).Msg("")
|
||||||
|
@ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil
|
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) ExpirePreAuthKey(
|
func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||||
|
@ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||||
|
|
||||||
response := make([]*v1.PreAuthKey, len(preAuthKeys))
|
response := make([]*v1.PreAuthKey, len(preAuthKeys))
|
||||||
for index, key := range preAuthKeys {
|
for index, key := range preAuthKeys {
|
||||||
response[index] = key.toProto()
|
response[index] = key.Proto()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
|
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
|
||||||
|
@ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine(
|
||||||
request.GetKey(),
|
request.GetKey(),
|
||||||
request.GetUser(),
|
request.GetUser(),
|
||||||
nil,
|
nil,
|
||||||
RegisterMethodCLI,
|
util.RegisterMethodCLI,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil
|
return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) GetMachine(
|
func (api headscaleV1APIServer) GetMachine(
|
||||||
|
@ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.GetMachineResponse{Machine: machine.toProto()}, nil
|
return &v1.GetMachineResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) SetTags(
|
func (api headscaleV1APIServer) SetTags(
|
||||||
|
@ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules)
|
err = api.h.db.SetTags(machine, request.GetTags())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &v1.SetTagsResponse{
|
return &v1.SetTagsResponse{
|
||||||
Machine: nil,
|
Machine: nil,
|
||||||
|
@ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
Strs("tags", request.GetTags()).
|
Strs("tags", request.GetTags()).
|
||||||
Msg("Changing tags of machine")
|
Msg("Changing tags of machine")
|
||||||
|
|
||||||
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
|
return &v1.SetTagsResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateTag(tag string) error {
|
func validateTag(tag string) error {
|
||||||
|
@ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine(
|
||||||
Time("expiry", *machine.Expiry).
|
Time("expiry", *machine.Expiry).
|
||||||
Msg("machine expired")
|
Msg("machine expired")
|
||||||
|
|
||||||
return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil
|
return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) RenameMachine(
|
func (api headscaleV1APIServer) RenameMachine(
|
||||||
|
@ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine(
|
||||||
Str("new_name", request.GetNewName()).
|
Str("new_name", request.GetNewName()).
|
||||||
Msg("machine renamed")
|
Msg("machine renamed")
|
||||||
|
|
||||||
return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil
|
return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) ListMachines(
|
func (api headscaleV1APIServer) ListMachines(
|
||||||
|
@ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
|
|
||||||
response := make([]*v1.Machine, len(machines))
|
response := make([]*v1.Machine, len(machines))
|
||||||
for index, machine := range machines {
|
for index, machine := range machines {
|
||||||
response[index] = machine.toProto()
|
response[index] = machine.Proto()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.ListMachinesResponse{Machines: response}, nil
|
return &v1.ListMachinesResponse{Machines: response}, nil
|
||||||
|
@ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
|
|
||||||
response := make([]*v1.Machine, len(machines))
|
response := make([]*v1.Machine, len(machines))
|
||||||
for index, machine := range machines {
|
for index, machine := range machines {
|
||||||
m := machine.toProto()
|
m := machine.Proto()
|
||||||
validTags, invalidTags := getTags(
|
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
|
||||||
api.h.aclPolicy,
|
|
||||||
machine,
|
machine,
|
||||||
api.h.cfg.OIDC.StripEmaildomain,
|
api.h.cfg.OIDC.StripEmaildomain,
|
||||||
)
|
)
|
||||||
|
@ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil
|
return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) GetRoutes(
|
func (api headscaleV1APIServer) GetRoutes(
|
||||||
|
@ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes(
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.GetRoutesResponse{
|
return &v1.GetRoutesResponse{
|
||||||
Routes: Routes(routes).toProto(),
|
Routes: types.Routes(routes).Proto(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes(
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.GetMachineRoutesResponse{
|
return &v1.GetMachineRoutesResponse{
|
||||||
Routes: Routes(routes).toProto(),
|
Routes: types.Routes(routes).Proto(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ExpireApiKeyRequest,
|
request *v1.ExpireApiKeyRequest,
|
||||||
) (*v1.ExpireApiKeyResponse, error) {
|
) (*v1.ExpireApiKeyResponse, error) {
|
||||||
var apiKey *APIKey
|
var apiKey *types.APIKey
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
|
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
|
||||||
|
@ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys(
|
||||||
|
|
||||||
response := make([]*v1.ApiKey, len(apiKeys))
|
response := make([]*v1.ApiKey, len(apiKeys))
|
||||||
for index, key := range apiKeys {
|
for index, key := range apiKeys {
|
||||||
response[index] = key.toProto()
|
response[index] = key.Proto()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
|
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
|
||||||
|
@ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newMachine := Machine{
|
newMachine := types.Machine{
|
||||||
MachineKey: request.GetKey(),
|
MachineKey: request.GetKey(),
|
||||||
Hostname: request.GetName(),
|
Hostname: request.GetName(),
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
|
@ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
LastSeen: &time.Time{},
|
LastSeen: &time.Time{},
|
||||||
LastSuccessfulUpdate: &time.Time{},
|
LastSuccessfulUpdate: &time.Time{},
|
||||||
|
|
||||||
HostInfo: HostInfo(hostinfo),
|
HostInfo: types.HostInfo(hostinfo),
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeKey := key.NodePublic{}
|
nodeKey := key.NodePublic{}
|
||||||
|
@ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
|
return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
|
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,142 +0,0 @@
|
||||||
package hscontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"go4.org/netipx"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This is borrowed from, and updated to use IPSet
|
|
||||||
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
|
|
||||||
// TODO(kradalby): contribute upstream and make public.
|
|
||||||
var (
|
|
||||||
zeroIP4 = netip.AddrFrom4([4]byte{})
|
|
||||||
zeroIP6 = netip.AddrFrom16([16]byte{})
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseIPSet parses arg as one:
|
|
||||||
//
|
|
||||||
// - an IP address (IPv4 or IPv6)
|
|
||||||
// - the string "*" to match everything (both IPv4 & IPv6)
|
|
||||||
// - a CIDR (e.g. "192.168.0.0/16")
|
|
||||||
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
|
||||||
//
|
|
||||||
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
|
|
||||||
// address (without a slash) treated as a CIDR of *bits length.
|
|
||||||
// nolint
|
|
||||||
func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
|
|
||||||
var ipSet netipx.IPSetBuilder
|
|
||||||
if arg == "*" {
|
|
||||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
|
|
||||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
|
|
||||||
|
|
||||||
return ipSet.IPSet()
|
|
||||||
}
|
|
||||||
if strings.Contains(arg, "/") {
|
|
||||||
pfx, err := netip.ParsePrefix(arg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if pfx != pfx.Masked() {
|
|
||||||
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipSet.AddPrefix(pfx)
|
|
||||||
|
|
||||||
return ipSet.IPSet()
|
|
||||||
}
|
|
||||||
if strings.Count(arg, "-") == 1 {
|
|
||||||
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
|
||||||
|
|
||||||
ip1, err := netip.ParseAddr(ip1s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ip2, err := netip.ParseAddr(ip2s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r := netipx.IPRangeFrom(ip1, ip2)
|
|
||||||
if !r.IsValid() {
|
|
||||||
return nil, fmt.Errorf("invalid IP range %q", arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range r.Prefixes() {
|
|
||||||
ipSet.AddPrefix(prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipSet.IPSet()
|
|
||||||
}
|
|
||||||
ip, err := netip.ParseAddr(arg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid IP address %q", arg)
|
|
||||||
}
|
|
||||||
bits8 := uint8(ip.BitLen())
|
|
||||||
if bits != nil {
|
|
||||||
if *bits < 0 || *bits > int(bits8) {
|
|
||||||
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
|
|
||||||
}
|
|
||||||
bits8 = uint8(*bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
|
|
||||||
|
|
||||||
return ipSet.IPSet()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Match struct {
|
|
||||||
Srcs *netipx.IPSet
|
|
||||||
Dests *netipx.IPSet
|
|
||||||
}
|
|
||||||
|
|
||||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
|
||||||
srcs := new(netipx.IPSetBuilder)
|
|
||||||
dests := new(netipx.IPSetBuilder)
|
|
||||||
|
|
||||||
for _, srcIP := range rule.SrcIPs {
|
|
||||||
set, _ := parseIPSet(srcIP, nil)
|
|
||||||
|
|
||||||
srcs.AddSet(set)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dest := range rule.DstPorts {
|
|
||||||
set, _ := parseIPSet(dest.IP, nil)
|
|
||||||
|
|
||||||
dests.AddSet(set)
|
|
||||||
}
|
|
||||||
|
|
||||||
srcsSet, _ := srcs.IPSet()
|
|
||||||
destsSet, _ := dests.IPSet()
|
|
||||||
|
|
||||||
match := Match{
|
|
||||||
Srcs: srcsSet,
|
|
||||||
Dests: destsSet,
|
|
||||||
}
|
|
||||||
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
|
||||||
for _, ip := range ips {
|
|
||||||
if m.Srcs.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
|
|
||||||
for _, ip := range ips {
|
|
||||||
if m.Dests.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -14,6 +14,8 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -638,7 +640,7 @@ func getUserName(
|
||||||
claims *IDTokenClaims,
|
claims *IDTokenClaims,
|
||||||
stripEmaildomain bool,
|
stripEmaildomain bool,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
userName, err := NormalizeToFQDNRules(
|
userName, err := util.NormalizeToFQDNRules(
|
||||||
claims.Email,
|
claims.Email,
|
||||||
stripEmaildomain,
|
stripEmaildomain,
|
||||||
)
|
)
|
||||||
|
@ -663,9 +665,9 @@ func getUserName(
|
||||||
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
userName string,
|
userName string,
|
||||||
) (*User, error) {
|
) (*types.User, error) {
|
||||||
user, err := h.db.GetUser(userName)
|
user, err := h.db.GetUser(userName)
|
||||||
if errors.Is(err, ErrUserNotFound) {
|
if errors.Is(err, db.ErrUserNotFound) {
|
||||||
user, err = h.db.CreateUser(userName)
|
user, err = h.db.CreateUser(userName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||||
|
|
||||||
func (h *Headscale) registerMachineForOIDCCallback(
|
func (h *Headscale) registerMachineForOIDCCallback(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
user *User,
|
user *types.User,
|
||||||
nodeKey *key.NodePublic,
|
nodeKey *key.NodePublic,
|
||||||
expiry time.Time,
|
expiry time.Time,
|
||||||
) error {
|
) error {
|
||||||
|
@ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback(
|
||||||
nodeKey.String(),
|
nodeKey.String(),
|
||||||
user.Name,
|
user.Name,
|
||||||
&expiry,
|
&expiry,
|
||||||
RegisterMethodOIDC,
|
util.RegisterMethodOIDC,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package hscontrol
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -12,6 +12,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tailscale/hujson"
|
"github.com/tailscale/hujson"
|
||||||
|
@ -22,12 +23,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errEmptyPolicy = errors.New("empty policy")
|
ErrEmptyPolicy = errors.New("empty policy")
|
||||||
errInvalidAction = errors.New("invalid action")
|
ErrInvalidAction = errors.New("invalid action")
|
||||||
errInvalidGroup = errors.New("invalid group")
|
ErrInvalidGroup = errors.New("invalid group")
|
||||||
errInvalidTag = errors.New("invalid tag")
|
ErrInvalidTag = errors.New("invalid tag")
|
||||||
errInvalidPortFormat = errors.New("invalid port format")
|
ErrInvalidPortFormat = errors.New("invalid port format")
|
||||||
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -56,7 +57,7 @@ const (
|
||||||
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
|
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
|
||||||
|
|
||||||
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
|
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
|
||||||
func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "LoadACLPolicy").
|
Str("func", "LoadACLPolicy").
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
|
@ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
||||||
|
|
||||||
policyFile, err := os.Open(path)
|
policyFile, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer policyFile.Close()
|
defer policyFile.Close()
|
||||||
|
|
||||||
policyBytes, err := io.ReadAll(policyFile)
|
policyBytes, err := io.ReadAll(policyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
|
@ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
||||||
|
|
||||||
switch filepath.Ext(path) {
|
switch filepath.Ext(path) {
|
||||||
case ".yml", ".yaml":
|
case ".yml", ".yaml":
|
||||||
return h.LoadACLPolicyFromBytes(policyBytes, "yaml")
|
return LoadACLPolicyFromBytes(policyBytes, "yaml")
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.LoadACLPolicyFromBytes(policyBytes, "hujson")
|
return LoadACLPolicyFromBytes(policyBytes, "hujson")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error {
|
func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
|
||||||
var policy ACLPolicy
|
var policy ACLPolicy
|
||||||
switch format {
|
switch format {
|
||||||
case "yaml":
|
case "yaml":
|
||||||
err := yaml.Unmarshal(acl, &policy)
|
err := yaml.Unmarshal(acl, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
ast, err := hujson.Parse(acl)
|
ast, err := hujson.Parse(acl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ast.Standardize()
|
ast.Standardize()
|
||||||
acl = ast.Pack()
|
acl = ast.Pack()
|
||||||
err = json.Unmarshal(acl, &policy)
|
err = json.Unmarshal(acl, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if policy.IsZero() {
|
if policy.IsZero() {
|
||||||
return errEmptyPolicy
|
return nil, ErrEmptyPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
h.aclPolicy = &policy
|
return &policy, nil
|
||||||
|
|
||||||
return h.UpdateACLRules()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) UpdateACLRules() error {
|
// TODO(kradalby): This needs to be replace with something that generates
|
||||||
machines, err := h.db.ListMachines()
|
// the rules as needed and not stores it on the global object, rules are
|
||||||
if err != nil {
|
// per node and that should be taken into account.
|
||||||
return err
|
func GenerateFilterRules(
|
||||||
|
policy *ACLPolicy,
|
||||||
|
machines types.Machines,
|
||||||
|
stripEmailDomain bool,
|
||||||
|
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||||
|
if policy == nil {
|
||||||
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.aclPolicy == nil {
|
rules, err := policy.generateFilterRules(machines, stripEmailDomain)
|
||||||
return errEmptyPolicy
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
||||||
h.aclRules = rules
|
|
||||||
|
|
||||||
|
var sshPolicy *tailcfg.SSHPolicy
|
||||||
if featureEnableSSH() {
|
if featureEnableSSH() {
|
||||||
sshRules, err := h.generateSSHRules()
|
sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
|
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
|
||||||
if h.sshPolicy == nil {
|
if sshPolicy == nil {
|
||||||
h.sshPolicy = &tailcfg.SSHPolicy{}
|
sshPolicy = &tailcfg.SSHPolicy{}
|
||||||
}
|
}
|
||||||
h.sshPolicy.Rules = sshRules
|
sshPolicy.Rules = sshRules
|
||||||
} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 {
|
} else if policy != nil && len(policy.SSHs) > 0 {
|
||||||
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
|
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return rules, sshPolicy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
|
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
|
||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
func (pol *ACLPolicy) generateFilterRules(
|
func (pol *ACLPolicy) generateFilterRules(
|
||||||
machines []Machine,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
|
|
||||||
for index, acl := range pol.ACLs {
|
for index, acl := range pol.ACLs {
|
||||||
if acl.Action != "accept" {
|
if acl.Action != "accept" {
|
||||||
return nil, errInvalidAction
|
return nil, ErrInvalidAction
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIPs := []string{}
|
srcIPs := []string{}
|
||||||
|
@ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
func generateSSHRules(
|
||||||
|
policy *ACLPolicy,
|
||||||
|
machines types.Machines,
|
||||||
|
stripEmailDomain bool,
|
||||||
|
) ([]*tailcfg.SSHRule, error) {
|
||||||
rules := []*tailcfg.SSHRule{}
|
rules := []*tailcfg.SSHRule{}
|
||||||
|
|
||||||
if h.aclPolicy == nil {
|
if policy == nil {
|
||||||
return nil, errEmptyPolicy
|
return nil, ErrEmptyPolicy
|
||||||
}
|
|
||||||
|
|
||||||
machines, err := h.db.ListMachines()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
acceptAction := tailcfg.SSHAction{
|
acceptAction := tailcfg.SSHAction{
|
||||||
|
@ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||||
AllowLocalPortForwarding: false,
|
AllowLocalPortForwarding: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
for index, sshACL := range h.aclPolicy.SSHs {
|
for index, sshACL := range policy.SSHs {
|
||||||
action := rejectAction
|
action := rejectAction
|
||||||
switch sshACL.Action {
|
switch sshACL.Action {
|
||||||
case "accept":
|
case "accept":
|
||||||
|
@ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action)
|
Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action)
|
||||||
|
|
||||||
return nil, err
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
||||||
|
@ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||||
Any: true,
|
Any: true,
|
||||||
})
|
})
|
||||||
} else if isGroup(rawSrc) {
|
} else if isGroup(rawSrc) {
|
||||||
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
|
users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
||||||
|
@ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expandedSrcs, err := h.aclPolicy.expandAlias(
|
expandedSrcs, err := policy.ExpandAlias(
|
||||||
machines,
|
machines,
|
||||||
rawSrc,
|
rawSrc,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
stripEmailDomain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||||
// with the given src alias.
|
// with the given src alias.
|
||||||
func (pol *ACLPolicy) getIPsFromSource(
|
func (pol *ACLPolicy) getIPsFromSource(
|
||||||
src string,
|
src string,
|
||||||
machines []Machine,
|
machines types.Machines,
|
||||||
stripEmaildomain bool,
|
stripEmaildomain bool,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
|
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
}
|
}
|
||||||
|
@ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource(
|
||||||
// which are associated with the dest alias.
|
// which are associated with the dest alias.
|
||||||
func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||||
dest string,
|
dest string,
|
||||||
machines []Machine,
|
machines types.Machines,
|
||||||
needsWildcard bool,
|
needsWildcard bool,
|
||||||
stripEmaildomain bool,
|
stripEmaildomain bool,
|
||||||
) ([]tailcfg.NetPortRange, error) {
|
) ([]tailcfg.NetPortRange, error) {
|
||||||
|
@ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to parse destination, tokens %v: %w",
|
"failed to parse destination, tokens %v: %w",
|
||||||
tokens,
|
tokens,
|
||||||
errInvalidPortFormat,
|
ErrInvalidPortFormat,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
tokens = []string{maybeIPv6Str, port}
|
tokens = []string{maybeIPv6Str, port}
|
||||||
|
@ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
expanded, err := pol.expandAlias(
|
expanded, err := pol.ExpandAlias(
|
||||||
machines,
|
machines,
|
||||||
alias,
|
alias,
|
||||||
stripEmaildomain,
|
stripEmaildomain,
|
||||||
|
@ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||||
// - an ip
|
// - an ip
|
||||||
// - a cidr
|
// - a cidr
|
||||||
// and transform these in IPAddresses.
|
// and transform these in IPAddresses.
|
||||||
func (pol *ACLPolicy) expandAlias(
|
func (pol *ACLPolicy) ExpandAlias(
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
alias string,
|
alias string,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
if isWildcard(alias) {
|
if isWildcard(alias) {
|
||||||
return parseIPSet("*", nil)
|
return util.ParseIPSet("*", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
@ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias(
|
||||||
// if alias is an host
|
// if alias is an host
|
||||||
// Note, this is recursive.
|
// Note, this is recursive.
|
||||||
if h, ok := pol.Hosts[alias]; ok {
|
if h, ok := pol.Hosts[alias]; ok {
|
||||||
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry")
|
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||||
|
|
||||||
return pol.expandAlias(machines, h.String(), stripEmailDomain)
|
return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is an IP
|
// if alias is an IP
|
||||||
|
@ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias(
|
||||||
// we assume in this function that we only have nodes from 1 user.
|
// we assume in this function that we only have nodes from 1 user.
|
||||||
func excludeCorrectlyTaggedNodes(
|
func excludeCorrectlyTaggedNodes(
|
||||||
aclPolicy *ACLPolicy,
|
aclPolicy *ACLPolicy,
|
||||||
nodes []Machine,
|
nodes types.Machines,
|
||||||
user string,
|
user string,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) []Machine {
|
) types.Machines {
|
||||||
out := []Machine{}
|
out := types.Machines{}
|
||||||
tags := []string{}
|
tags := []string{}
|
||||||
for tag := range aclPolicy.TagOwners {
|
for tag := range aclPolicy.TagOwners {
|
||||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
||||||
|
@ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if needsWildcard {
|
if needsWildcard {
|
||||||
return nil, errWildcardIsNeeded
|
return nil, ErrWildcardIsNeeded
|
||||||
}
|
}
|
||||||
|
|
||||||
ports := []tailcfg.PortRange{}
|
ports := []tailcfg.PortRange{}
|
||||||
|
@ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||||
})
|
})
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, errInvalidPortFormat
|
return nil, ErrInvalidPortFormat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports, nil
|
return &ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterMachinesByUser(machines []Machine, user string) []Machine {
|
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||||
out := []Machine{}
|
out := types.Machines{}
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
if machine.User.Name == user {
|
if machine.User.Name == user {
|
||||||
out = append(out, machine)
|
out = append(out, machine)
|
||||||
|
@ -664,7 +664,7 @@ func getTagOwners(
|
||||||
if !ok {
|
if !ok {
|
||||||
return []string{}, fmt.Errorf(
|
return []string{}, fmt.Errorf(
|
||||||
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
|
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
|
||||||
errInvalidTag,
|
ErrInvalidTag,
|
||||||
tag,
|
tag,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||||
return []string{}, fmt.Errorf(
|
return []string{}, fmt.Errorf(
|
||||||
"group %v isn't registered. %w",
|
"group %v isn't registered. %w",
|
||||||
group,
|
group,
|
||||||
errInvalidGroup,
|
ErrInvalidGroup,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
for _, group := range aclGroups {
|
for _, group := range aclGroups {
|
||||||
if isGroup(group) {
|
if isGroup(group) {
|
||||||
return []string{}, fmt.Errorf(
|
return []string{}, fmt.Errorf(
|
||||||
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
|
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
|
||||||
errInvalidGroup,
|
ErrInvalidGroup,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
grp, err := NormalizeToFQDNRules(group, stripEmailDomain)
|
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, fmt.Errorf(
|
return []string{}, fmt.Errorf(
|
||||||
"failed to normalize group %q, err: %w",
|
"failed to normalize group %q, err: %w",
|
||||||
group,
|
group,
|
||||||
errInvalidGroup,
|
ErrInvalidGroup,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
users = append(users, grp)
|
users = append(users, grp)
|
||||||
|
@ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||||
|
|
||||||
func (pol *ACLPolicy) getIPsFromGroup(
|
func (pol *ACLPolicy) getIPsFromGroup(
|
||||||
group string,
|
group string,
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
@ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
|
||||||
|
|
||||||
func (pol *ACLPolicy) getIPsFromTag(
|
func (pol *ACLPolicy) getIPsFromTag(
|
||||||
alias string,
|
alias string,
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
@ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
// find tag owners
|
// find tag owners
|
||||||
owners, err := getTagOwners(pol, alias, stripEmailDomain)
|
owners, err := getTagOwners(pol, alias, stripEmailDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errInvalidTag) {
|
if errors.Is(err, ErrInvalidTag) {
|
||||||
ipSet, _ := build.IPSet()
|
ipSet, _ := build.IPSet()
|
||||||
if len(ipSet.Prefixes()) == 0 {
|
if len(ipSet.Prefixes()) == 0 {
|
||||||
return ipSet, fmt.Errorf(
|
return ipSet, fmt.Errorf(
|
||||||
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
|
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
|
||||||
errInvalidTag,
|
ErrInvalidTag,
|
||||||
alias,
|
alias,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
|
|
||||||
func (pol *ACLPolicy) getIPsForUser(
|
func (pol *ACLPolicy) getIPsForUser(
|
||||||
user string,
|
user string,
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
stripEmailDomain bool,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
@ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser(
|
||||||
|
|
||||||
func (pol *ACLPolicy) getIPsFromSingleIP(
|
func (pol *ACLPolicy) getIPsFromSingleIP(
|
||||||
ip netip.Addr,
|
ip netip.Addr,
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
|
log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip")
|
||||||
|
|
||||||
matches := machines.FilterByIP(ip)
|
matches := machines.FilterByIP(ip)
|
||||||
|
|
||||||
|
@ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
|
||||||
|
|
||||||
func (pol *ACLPolicy) getIPsFromIPPrefix(
|
func (pol *ACLPolicy) getIPsFromIPPrefix(
|
||||||
prefix netip.Prefix,
|
prefix netip.Prefix,
|
||||||
machines Machines,
|
machines types.Machines,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
|
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
@ -862,3 +862,65 @@ func isGroup(str string) bool {
|
||||||
func isTag(str string) bool {
|
func isTag(str string) bool {
|
||||||
return strings.HasPrefix(str, "tag:")
|
return strings.HasPrefix(str, "tag:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getTags will return the tags of the current machine.
|
||||||
|
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
|
||||||
|
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
||||||
|
func (pol *ACLPolicy) GetTagsOfMachine(
|
||||||
|
machine types.Machine,
|
||||||
|
stripEmailDomain bool,
|
||||||
|
) ([]string, []string) {
|
||||||
|
validTags := make([]string, 0)
|
||||||
|
invalidTags := make([]string, 0)
|
||||||
|
|
||||||
|
validTagMap := make(map[string]bool)
|
||||||
|
invalidTagMap := make(map[string]bool)
|
||||||
|
for _, tag := range machine.HostInfo.RequestTags {
|
||||||
|
owners, err := getTagOwners(pol, tag, stripEmailDomain)
|
||||||
|
if errors.Is(err, ErrInvalidTag) {
|
||||||
|
invalidTagMap[tag] = true
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var found bool
|
||||||
|
for _, owner := range owners {
|
||||||
|
if machine.User.Name == owner {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
validTagMap[tag] = true
|
||||||
|
} else {
|
||||||
|
invalidTagMap[tag] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for tag := range invalidTagMap {
|
||||||
|
invalidTags = append(invalidTags, tag)
|
||||||
|
}
|
||||||
|
for tag := range validTagMap {
|
||||||
|
validTags = append(validTags, tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return validTags, invalidTags
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||||
|
func FilterMachinesByACL(
|
||||||
|
machine *types.Machine,
|
||||||
|
machines types.Machines,
|
||||||
|
filter []tailcfg.FilterRule,
|
||||||
|
) types.Machines {
|
||||||
|
result := types.Machines{}
|
||||||
|
|
||||||
|
for index, peer := range machines {
|
||||||
|
if peer.ID == machine.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) {
|
||||||
|
result = append(result, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
||||||
package hscontrol
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
61
hscontrol/policy/matcher/matcher.go
Normal file
61
hscontrol/policy/matcher/matcher.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package matcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Match struct {
|
||||||
|
Srcs *netipx.IPSet
|
||||||
|
Dests *netipx.IPSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||||
|
srcs := new(netipx.IPSetBuilder)
|
||||||
|
dests := new(netipx.IPSetBuilder)
|
||||||
|
|
||||||
|
for _, srcIP := range rule.SrcIPs {
|
||||||
|
set, _ := util.ParseIPSet(srcIP, nil)
|
||||||
|
|
||||||
|
srcs.AddSet(set)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dest := range rule.DstPorts {
|
||||||
|
set, _ := util.ParseIPSet(dest.IP, nil)
|
||||||
|
|
||||||
|
dests.AddSet(set)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcsSet, _ := srcs.IPSet()
|
||||||
|
destsSet, _ := dests.IPSet()
|
||||||
|
|
||||||
|
match := Match{
|
||||||
|
Srcs: srcsSet,
|
||||||
|
Dests: destsSet,
|
||||||
|
}
|
||||||
|
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if m.Srcs.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if m.Dests.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
1
hscontrol/policy/matcher/matcher_test.go
Normal file
1
hscontrol/policy/matcher/matcher_test.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package matcher
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// that we rely on a method that calls back some how (OpenID or CLI)
|
// that we rely on a method that calls back some how (OpenID or CLI)
|
||||||
// We create the machine and then keep it around until a callback
|
// We create the machine and then keep it around until a callback
|
||||||
// happens
|
// happens
|
||||||
newMachine := Machine{
|
newMachine := types.Machine{
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
|
@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||||
)
|
)
|
||||||
if err != nil || storedMachineKey.IsZero() {
|
if err != nil || storedMachineKey.IsZero() {
|
||||||
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
|
if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil {
|
||||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("func", "RegistrationHandler").
|
Str("func", "RegistrationHandler").
|
||||||
|
@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
|
|
||||||
// If machine is not expired, and it is register, we have a already accepted this machine,
|
// If machine is not expired, and it is register, we have a already accepted this machine,
|
||||||
// let it proceed with a valid registration
|
// let it proceed with a valid registration
|
||||||
if !machine.isExpired() {
|
if !machine.IsExpired() {
|
||||||
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
|
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
||||||
!machine.isExpired() {
|
!machine.IsExpired() {
|
||||||
h.handleMachineRefreshKeyCommon(
|
h.handleMachineRefreshKeyCommon(
|
||||||
writer,
|
writer,
|
||||||
registerRequest,
|
registerRequest,
|
||||||
|
@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
|
pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
|
|
||||||
if pak != nil {
|
if pak != nil {
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
} else {
|
} else {
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
aclTags := pak.toProto().AclTags
|
aclTags := pak.Proto().AclTags
|
||||||
if len(aclTags) > 0 {
|
if len(aclTags) > 0 {
|
||||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||||
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
|
err = h.db.SetTags(machine, aclTags)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machineToRegister := Machine{
|
machineToRegister := types.Machine{
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
UserID: pak.User.ID,
|
UserID: pak.User.ID,
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
Expiry: ®isterRequest.Expiry,
|
Expiry: ®isterRequest.Expiry,
|
||||||
NodeKey: nodeKey,
|
NodeKey: nodeKey,
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
ForcedTags: pak.toProto().AclTags,
|
ForcedTags: pak.Proto().AclTags,
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err = h.db.RegisterMachine(
|
machine, err = h.db.RegisterMachine(
|
||||||
|
@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("could not register machine")
|
Msg("could not register machine")
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to use pre-auth key")
|
Msg("Failed to use pre-auth key")
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
resp.User = *pak.User.toTailscaleUser()
|
resp.User = *pak.User.TailscaleUser()
|
||||||
// Provide LoginName when registering with pre-auth key
|
// Provide LoginName when registering with pre-auth key
|
||||||
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
||||||
resp.Login = *pak.User.toTailscaleLogin()
|
resp.Login = *pak.User.TailscaleLogin()
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name).
|
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
|
||||||
Inc()
|
Inc()
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineLogOutCommon(
|
func (h *Headscale) handleMachineLogOutCommon(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
machine Machine,
|
machine types.Machine,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
resp.NodeKeyExpired = true
|
resp.NodeKeyExpired = true
|
||||||
resp.User = *machine.User.toTailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if machine.isEphemeral() {
|
if machine.IsEphemeral() {
|
||||||
err = h.db.HardDeleteMachine(&machine)
|
err = h.db.HardDeleteMachine(&machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineValidRegistrationCommon(
|
func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
machine Machine,
|
machine types.Machine,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
resp.User = *machine.User.toTailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
resp.Login = *machine.User.toTailscaleLogin()
|
resp.Login = *machine.User.TailscaleLogin()
|
||||||
|
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||||
func (h *Headscale) handleMachineRefreshKeyCommon(
|
func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine types.Machine,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
|
||||||
|
|
||||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey)
|
||||||
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.User = *machine.User.toTailscaleUser()
|
resp.User = *machine.User.TailscaleUser()
|
||||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
|
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine types.Machine,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName")
|
||||||
func (h *Headscale) handlePollCommon(
|
func (h *Headscale) handlePollCommon(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
|
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||||
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
err := h.db.processMachineRoutes(machine)
|
err := h.db.ProcessMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
// update ACLRules with peer informations (to update server tags if necessary)
|
// update ACLRules with peer informations (to update server tags if necessary)
|
||||||
if h.aclPolicy != nil {
|
if h.ACLPolicy != nil {
|
||||||
err := h.UpdateACLRules()
|
// TODO(kradalby): Since this is not blocking, I might have introduced a bug here.
|
||||||
if err != nil {
|
// It will be resolved later as we change up the policy stuff.
|
||||||
log.Error().
|
h.policyUpdateChan <- struct{}{}
|
||||||
Caller().
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Err(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// update routes with peer information
|
// update routes with peer information
|
||||||
err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine)
|
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -78,8 +74,7 @@ func (h *Headscale) handlePollCommon(
|
||||||
machine.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.db.Updates(machine).Error; err != nil {
|
if err := h.db.MachineSave(machine); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
|
@ -91,7 +86,6 @@ func (h *Headscale) handlePollCommon(
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon(
|
||||||
func (h *Headscale) pollNetMapStream(
|
func (h *Headscale) pollNetMapStream(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
ctxReq context.Context,
|
ctxReq context.Context,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
pollDataChan chan []byte,
|
pollDataChan chan []byte,
|
||||||
keepAliveChan chan []byte,
|
keepAliveChan chan []byte,
|
||||||
|
@ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
if h.db.isOutdated(machine, h.getLastStateChange()) {
|
if h.db.IsOutdated(machine, h.getLastStateChange()) {
|
||||||
var lastUpdate time.Time
|
var lastUpdate time.Time
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
lastUpdate = *machine.LastSuccessfulUpdate
|
||||||
|
@ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
updateChan chan struct{},
|
updateChan chan struct{},
|
||||||
keepAliveChan chan []byte,
|
keepAliveChan chan []byte,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -15,7 +16,7 @@ import (
|
||||||
|
|
||||||
func (h *Headscale) getMapResponseData(
|
func (h *Headscale) getMapResponseData(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
mapResponse, err := h.generateMapResponse(mapRequest, machine)
|
mapResponse, err := h.generateMapResponse(mapRequest, machine)
|
||||||
|
@ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData(
|
||||||
|
|
||||||
func (h *Headscale) getMapKeepAliveResponseData(
|
func (h *Headscale) getMapKeepAliveResponseData(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *types.Machine,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
keepAliveResponse := tailcfg.MapResponse{
|
keepAliveResponse := tailcfg.MapResponse{
|
||||||
|
|
|
@ -18,7 +18,7 @@ type Suite struct{}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tmpDir string
|
tmpDir string
|
||||||
app Headscale
|
app *Headscale
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) SetUpTest(c *check.C) {
|
func (s *Suite) SetUpTest(c *check.C) {
|
||||||
|
@ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
os.RemoveAll(tmpDir)
|
os.RemoveAll(tmpDir)
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
tmpDir, err = os.MkdirTemp("", "autoygg-client-test2")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
|
PrivateKeyPath: tmpDir + "/private.key",
|
||||||
|
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
|
||||||
|
DBtype: "sqlite3",
|
||||||
|
DBpath: tmpDir + "/headscale_test.db",
|
||||||
IPPrefixes: []netip.Prefix{
|
IPPrefixes: []netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
},
|
},
|
||||||
|
@ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift
|
app, err = NewHeadscale(&cfg)
|
||||||
app = Headscale{
|
|
||||||
cfg: &cfg,
|
|
||||||
dbType: "sqlite3",
|
|
||||||
dbString: tmpDir + "/headscale_test.db",
|
|
||||||
|
|
||||||
stateUpdateChan: make(chan struct{}),
|
|
||||||
cancelStateUpdateChan: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
go app.watchStateChannel()
|
|
||||||
|
|
||||||
db, err := NewHeadscaleDatabase(
|
|
||||||
app.dbType,
|
|
||||||
app.dbString,
|
|
||||||
cfg.OIDC.StripEmaildomain,
|
|
||||||
false,
|
|
||||||
app.stateUpdateChan,
|
|
||||||
cfg.IPPrefixes,
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
app.db = db
|
|
||||||
}
|
}
|
41
hscontrol/types/api_key.go
Normal file
41
hscontrol/types/api_key.go
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
||||||
|
// headscale.
|
||||||
|
type APIKey struct {
|
||||||
|
ID uint64 `gorm:"primary_key"`
|
||||||
|
Prefix string `gorm:"uniqueIndex"`
|
||||||
|
Hash []byte
|
||||||
|
|
||||||
|
CreatedAt *time.Time
|
||||||
|
Expiration *time.Time
|
||||||
|
LastSeen *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *APIKey) Proto() *v1.ApiKey {
|
||||||
|
protoKey := v1.ApiKey{
|
||||||
|
Id: key.ID,
|
||||||
|
Prefix: key.Prefix,
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Expiration != nil {
|
||||||
|
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.CreatedAt != nil {
|
||||||
|
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.LastSeen != nil {
|
||||||
|
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protoKey
|
||||||
|
}
|
108
hscontrol/types/common.go
Normal file
108
hscontrol/types/common.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||||
|
|
||||||
|
// This is a "wrapper" type around tailscales
|
||||||
|
// Hostinfo to allow us to add database "serialization"
|
||||||
|
// methods. This allows us to use a typed values throughout
|
||||||
|
// the code and not have to marshal/unmarshal and error
|
||||||
|
// check all over the code.
|
||||||
|
type HostInfo tailcfg.Hostinfo
|
||||||
|
|
||||||
|
func (hi *HostInfo) Scan(destination interface{}) error {
|
||||||
|
switch value := destination.(type) {
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, hi)
|
||||||
|
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), hi)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value return json value, implement driver.Valuer interface.
|
||||||
|
func (hi HostInfo) Value() (driver.Value, error) {
|
||||||
|
bytes, err := json.Marshal(hi)
|
||||||
|
|
||||||
|
return string(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPPrefix netip.Prefix
|
||||||
|
|
||||||
|
func (i *IPPrefix) Scan(destination interface{}) error {
|
||||||
|
switch value := destination.(type) {
|
||||||
|
case string:
|
||||||
|
prefix, err := netip.ParsePrefix(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = IPPrefix(prefix)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value return json value, implement driver.Valuer interface.
|
||||||
|
func (i IPPrefix) Value() (driver.Value, error) {
|
||||||
|
prefixStr := netip.Prefix(i).String()
|
||||||
|
|
||||||
|
return prefixStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPPrefixes []netip.Prefix
|
||||||
|
|
||||||
|
func (i *IPPrefixes) Scan(destination interface{}) error {
|
||||||
|
switch value := destination.(type) {
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, i)
|
||||||
|
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), i)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value return json value, implement driver.Valuer interface.
|
||||||
|
func (i IPPrefixes) Value() (driver.Value, error) {
|
||||||
|
bytes, err := json.Marshal(i)
|
||||||
|
|
||||||
|
return string(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
type StringList []string
|
||||||
|
|
||||||
|
func (i *StringList) Scan(destination interface{}) error {
|
||||||
|
switch value := destination.(type) {
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, i)
|
||||||
|
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), i)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value return json value, implement driver.Valuer interface.
|
||||||
|
func (i StringList) Value() (driver.Value, error) {
|
||||||
|
bytes, err := json.Marshal(i)
|
||||||
|
|
||||||
|
return string(bytes), err
|
||||||
|
}
|
254
hscontrol/types/machine.go
Normal file
254
hscontrol/types/machine.go
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TODO(kradalby): Move out of here when we got circdeps under control.
|
||||||
|
keepAliveInterval = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
|
||||||
|
|
||||||
|
// Machine is a Headscale client.
|
||||||
|
type Machine struct {
|
||||||
|
ID uint64 `gorm:"primary_key"`
|
||||||
|
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
||||||
|
NodeKey string
|
||||||
|
DiscoKey string
|
||||||
|
IPAddresses MachineAddresses
|
||||||
|
|
||||||
|
// Hostname represents the name given by the Tailscale
|
||||||
|
// client during registration
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// Givenname represents either:
|
||||||
|
// a DNS normalized version of Hostname
|
||||||
|
// a valid name set by the User
|
||||||
|
//
|
||||||
|
// GivenName is the name used in all DNS related
|
||||||
|
// parts of headscale.
|
||||||
|
GivenName string `gorm:"type:varchar(63);unique_index"`
|
||||||
|
UserID uint
|
||||||
|
User User `gorm:"foreignKey:UserID"`
|
||||||
|
|
||||||
|
RegisterMethod string
|
||||||
|
|
||||||
|
ForcedTags StringList
|
||||||
|
|
||||||
|
// TODO(kradalby): This seems like irrelevant information?
|
||||||
|
AuthKeyID uint
|
||||||
|
AuthKey *PreAuthKey
|
||||||
|
|
||||||
|
LastSeen *time.Time
|
||||||
|
LastSuccessfulUpdate *time.Time
|
||||||
|
Expiry *time.Time
|
||||||
|
|
||||||
|
HostInfo HostInfo
|
||||||
|
Endpoints StringList
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
Machines []Machine
|
||||||
|
MachinesP []*Machine
|
||||||
|
)
|
||||||
|
|
||||||
|
type MachineAddresses []netip.Addr
|
||||||
|
|
||||||
|
func (ma MachineAddresses) ToStringSlice() []string {
|
||||||
|
strSlice := make([]string, 0, len(ma))
|
||||||
|
for _, addr := range ma {
|
||||||
|
strSlice = append(strSlice, addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return strSlice
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendToIPSet adds the individual ips in MachineAddresses to a
|
||||||
|
// given netipx.IPSetBuilder.
|
||||||
|
func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||||
|
for _, ip := range ma {
|
||||||
|
build.Add(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ma *MachineAddresses) Scan(destination interface{}) error {
|
||||||
|
switch value := destination.(type) {
|
||||||
|
case string:
|
||||||
|
addresses := strings.Split(value, ",")
|
||||||
|
*ma = (*ma)[:0]
|
||||||
|
for _, addr := range addresses {
|
||||||
|
if len(addr) < 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsed, err := netip.ParseAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*ma = append(*ma, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value return json value, implement driver.Valuer interface.
|
||||||
|
func (ma MachineAddresses) Value() (driver.Value, error) {
|
||||||
|
addresses := strings.Join(ma.ToStringSlice(), ",")
|
||||||
|
|
||||||
|
return addresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired returns whether the machine registration has expired.
|
||||||
|
func (machine Machine) IsExpired() bool {
|
||||||
|
// If Expiry is not set, the client has not indicated that
|
||||||
|
// it wants an expiry time, it is therefor considered
|
||||||
|
// to mean "not expired"
|
||||||
|
if machine.Expiry == nil || machine.Expiry.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now().UTC().After(*machine.Expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOnline returns if the machine is connected to Headscale.
|
||||||
|
// This is really a naive implementation, as we don't really see
|
||||||
|
// if there is a working connection between the client and the server.
|
||||||
|
func (machine *Machine) IsOnline() bool {
|
||||||
|
if machine.LastSeen == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.IsExpired() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return machine.LastSeen.After(time.Now().Add(-keepAliveInterval))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEphemeral returns if the machine is registered as an Ephemeral node.
|
||||||
|
// https://tailscale.com/kb/1111/ephemeral-nodes/
|
||||||
|
func (machine *Machine) IsEphemeral() bool {
|
||||||
|
return machine.AuthKey != nil && machine.AuthKey.Ephemeral
|
||||||
|
}
|
||||||
|
|
||||||
|
func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool {
|
||||||
|
for _, rule := range filter {
|
||||||
|
// TODO(kradalby): Cache or pregen this
|
||||||
|
matcher := matcher.MatchFromFilterRule(rule)
|
||||||
|
|
||||||
|
if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (machines Machines) FilterByIP(ip netip.Addr) Machines {
|
||||||
|
found := make(Machines, 0)
|
||||||
|
|
||||||
|
for _, machine := range machines {
|
||||||
|
for _, mIP := range machine.IPAddresses {
|
||||||
|
if ip == mIP {
|
||||||
|
found = append(found, machine)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (machine *Machine) Proto() *v1.Machine {
|
||||||
|
machineProto := &v1.Machine{
|
||||||
|
Id: machine.ID,
|
||||||
|
MachineKey: machine.MachineKey,
|
||||||
|
|
||||||
|
NodeKey: machine.NodeKey,
|
||||||
|
DiscoKey: machine.DiscoKey,
|
||||||
|
IpAddresses: machine.IPAddresses.ToStringSlice(),
|
||||||
|
Name: machine.Hostname,
|
||||||
|
GivenName: machine.GivenName,
|
||||||
|
User: machine.User.Proto(),
|
||||||
|
ForcedTags: machine.ForcedTags,
|
||||||
|
Online: machine.IsOnline(),
|
||||||
|
|
||||||
|
// TODO(kradalby): Implement register method enum converter
|
||||||
|
// RegisterMethod: ,
|
||||||
|
|
||||||
|
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.AuthKey != nil {
|
||||||
|
machineProto.PreAuthKey = machine.AuthKey.Proto()
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.LastSeen != nil {
|
||||||
|
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
|
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||||
|
*machine.LastSuccessfulUpdate,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if machine.Expiry != nil {
|
||||||
|
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return machineProto
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||||
|
func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
|
||||||
|
return tailcfg.Hostinfo(machine.HostInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (machine Machine) String() string {
|
||||||
|
return machine.Hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
func (machines Machines) String() string {
|
||||||
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
|
for index, machine := range machines {
|
||||||
|
temp[index] = machine.Hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Remove when we have generics...
|
||||||
|
func (machines MachinesP) String() string {
|
||||||
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
|
for index, machine := range machines {
|
||||||
|
temp[index] = machine.Hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
|
}
|
1
hscontrol/types/machine_test.go
Normal file
1
hscontrol/types/machine_test.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package types
|
58
hscontrol/types/preauth_key.go
Normal file
58
hscontrol/types/preauth_key.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||||
|
type PreAuthKey struct {
|
||||||
|
ID uint64 `gorm:"primary_key"`
|
||||||
|
Key string
|
||||||
|
UserID uint
|
||||||
|
User User
|
||||||
|
Reusable bool
|
||||||
|
Ephemeral bool `gorm:"default:false"`
|
||||||
|
Used bool `gorm:"default:false"`
|
||||||
|
ACLTags []PreAuthKeyACLTag
|
||||||
|
|
||||||
|
CreatedAt *time.Time
|
||||||
|
Expiration *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
|
||||||
|
type PreAuthKeyACLTag struct {
|
||||||
|
ID uint64 `gorm:"primary_key"`
|
||||||
|
PreAuthKeyID uint64
|
||||||
|
Tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||||
|
protoKey := v1.PreAuthKey{
|
||||||
|
User: key.User.Name,
|
||||||
|
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||||
|
Key: key.Key,
|
||||||
|
Ephemeral: key.Ephemeral,
|
||||||
|
Reusable: key.Reusable,
|
||||||
|
Used: key.Used,
|
||||||
|
AclTags: make([]string, len(key.ACLTags)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Expiration != nil {
|
||||||
|
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.CreatedAt != nil {
|
||||||
|
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := range key.ACLTags {
|
||||||
|
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protoKey
|
||||||
|
}
|
71
hscontrol/types/routes.go
Normal file
71
hscontrol/types/routes.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
||||||
|
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
gorm.Model
|
||||||
|
|
||||||
|
MachineID uint64
|
||||||
|
Machine Machine
|
||||||
|
Prefix IPPrefix
|
||||||
|
|
||||||
|
Advertised bool
|
||||||
|
Enabled bool
|
||||||
|
IsPrimary bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Routes []Route
|
||||||
|
|
||||||
|
func (r *Route) String() string {
|
||||||
|
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Route) IsExitRoute() bool {
|
||||||
|
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs Routes) Prefixes() []netip.Prefix {
|
||||||
|
prefixes := make([]netip.Prefix, len(rs))
|
||||||
|
for i, r := range rs {
|
||||||
|
prefixes[i] = netip.Prefix(r.Prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs Routes) Proto() []*v1.Route {
|
||||||
|
protoRoutes := []*v1.Route{}
|
||||||
|
|
||||||
|
for _, route := range rs {
|
||||||
|
protoRoute := v1.Route{
|
||||||
|
Id: uint64(route.ID),
|
||||||
|
Machine: route.Machine.Proto(),
|
||||||
|
Prefix: netip.Prefix(route.Prefix).String(),
|
||||||
|
Advertised: route.Advertised,
|
||||||
|
Enabled: route.Enabled,
|
||||||
|
IsPrimary: route.IsPrimary,
|
||||||
|
CreatedAt: timestamppb.New(route.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(route.UpdatedAt),
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.DeletedAt.Valid {
|
||||||
|
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
protoRoutes = append(protoRoutes, &protoRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
return protoRoutes
|
||||||
|
}
|
55
hscontrol/types/users.go
Normal file
55
hscontrol/types/users.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// User is the way Headscale implements the concept of users in Tailscale
|
||||||
|
//
|
||||||
|
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||||
|
// that contain our machines.
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"unique"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *User) TailscaleUser() *tailcfg.User {
|
||||||
|
user := tailcfg.User{
|
||||||
|
ID: tailcfg.UserID(n.ID),
|
||||||
|
LoginName: n.Name,
|
||||||
|
DisplayName: n.Name,
|
||||||
|
ProfilePicURL: "",
|
||||||
|
Domain: "headscale.net",
|
||||||
|
Logins: []tailcfg.LoginID{},
|
||||||
|
Created: time.Time{},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *User) TailscaleLogin() *tailcfg.Login {
|
||||||
|
login := tailcfg.Login{
|
||||||
|
ID: tailcfg.LoginID(n.ID),
|
||||||
|
LoginName: n.Name,
|
||||||
|
DisplayName: n.Name,
|
||||||
|
ProfilePicURL: "",
|
||||||
|
Domain: "headscale.net",
|
||||||
|
}
|
||||||
|
|
||||||
|
return &login
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *User) Proto() *v1.User {
|
||||||
|
return &v1.User{
|
||||||
|
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
||||||
|
Name: n.Name,
|
||||||
|
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,415 +0,0 @@
|
||||||
package hscontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gopkg.in/check.v1"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
|
||||||
user, err := app.db.CreateUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(user.Name, check.Equals, "test")
|
|
||||||
|
|
||||||
users, err := app.db.ListUsers()
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(len(users), check.Equals, 1)
|
|
||||||
|
|
||||||
err = app.db.DestroyUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetUser("test")
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
|
||||||
err := app.db.DestroyUser("test")
|
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
|
||||||
|
|
||||||
user, err := app.db.CreateUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
err = app.db.DestroyUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
|
||||||
// destroying a user also deletes all associated preauthkeys
|
|
||||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
|
||||||
|
|
||||||
user, err = app.db.CreateUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine := Machine{
|
|
||||||
ID: 0,
|
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testmachine",
|
|
||||||
UserID: user.ID,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
AuthKeyID: uint(pak.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(&machine)
|
|
||||||
|
|
||||||
err = app.db.DestroyUser("test")
|
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestRenameUser(c *check.C) {
|
|
||||||
userTest, err := app.db.CreateUser("test")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(userTest.Name, check.Equals, "test")
|
|
||||||
|
|
||||||
users, err := app.db.ListUsers()
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(len(users), check.Equals, 1)
|
|
||||||
|
|
||||||
err = app.db.RenameUser("test", "test-renamed")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetUser("test")
|
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
|
||||||
|
|
||||||
_, err = app.db.GetUser("test-renamed")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
err = app.db.RenameUser("test-does-not-exit", "test")
|
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
|
||||||
|
|
||||||
userTest2, err := app.db.CreateUser("test2")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
|
||||||
|
|
||||||
err = app.db.RenameUser("test2", "test-renamed")
|
|
||||||
c.Assert(err, check.Equals, ErrUserExists)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
|
||||||
userShared1, err := app.db.CreateUser("shared1")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared2, err := app.db.CreateUser("shared2")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userShared3, err := app.db.CreateUser("shared3")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared2, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared2.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKeyShared3, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared3.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
preAuthKey2Shared1, err := app.db.CreatePreAuthKey(
|
|
||||||
userShared1.Name,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
|
||||||
ID: 1,
|
|
||||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
|
||||||
Hostname: "test_get_shared_nodes_1",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(machineInShared1)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
|
||||||
ID: 2,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_2",
|
|
||||||
UserID: userShared2.ID,
|
|
||||||
User: *userShared2,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(machineInShared2)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
|
||||||
ID: 3,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_3",
|
|
||||||
UserID: userShared3.ID,
|
|
||||||
User: *userShared3,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
|
||||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(machineInShared3)
|
|
||||||
|
|
||||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
|
||||||
ID: 4,
|
|
||||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
|
||||||
Hostname: "test_get_shared_nodes_4",
|
|
||||||
UserID: userShared1.ID,
|
|
||||||
User: *userShared1,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
|
||||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(machine2InShared1)
|
|
||||||
|
|
||||||
peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
userProfiles := app.db.getMapResponseUserProfiles(
|
|
||||||
*machineInShared1,
|
|
||||||
peersOfMachine1InShared1,
|
|
||||||
)
|
|
||||||
|
|
||||||
c.Assert(len(userProfiles), check.Equals, 3)
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, userProfiles := range userProfiles {
|
|
||||||
if userProfiles.DisplayName == userShared1.Name {
|
|
||||||
found = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Assert(found, check.Equals, true)
|
|
||||||
|
|
||||||
found = false
|
|
||||||
for _, userProfile := range userProfiles {
|
|
||||||
if userProfile.DisplayName == userShared2.Name {
|
|
||||||
found = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Assert(found, check.Equals, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeToFQDNRules(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normalize simple name",
|
|
||||||
args: args{
|
|
||||||
name: "normalize-simple.name",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "normalize-simple.name",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize an email",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar@example.com",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "foo.bar.example.com",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize an email domain should be removed",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar@example.com",
|
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
want: "foo.bar",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "strip enabled no email passed as argument",
|
|
||||||
args: args{
|
|
||||||
name: "not-email-and-strip-enabled",
|
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
want: "not-email-and-strip-enabled",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize complex email",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar+complex-email@example.com",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "foo.bar-complex-email.example.com",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user name with space",
|
|
||||||
args: args{
|
|
||||||
name: "name space",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "name-space",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with quote",
|
|
||||||
args: args{
|
|
||||||
name: "Jamie's iPhone 5",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "jamies-iphone-5",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf(
|
|
||||||
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
|
||||||
err,
|
|
||||||
tt.wantErr,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckForFQDNRules(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid: user",
|
|
||||||
args: args{name: "valid-user"},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid: capitalized user",
|
|
||||||
args: args{name: "Invalid-CapItaLIzed-user"},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid: email as user",
|
|
||||||
args: args{name: "foo.bar@example.com"},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid: chars in user name",
|
|
||||||
args: args{name: "super-user+name"},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid: too long name for user",
|
|
||||||
args: args{
|
|
||||||
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
|
||||||
oldUser, err := app.db.CreateUser("old")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
newUser, err := app.db.CreateUser("new")
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
machine := Machine{
|
|
||||||
ID: 0,
|
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testmachine",
|
|
||||||
UserID: oldUser.ID,
|
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
|
||||||
AuthKeyID: uint(pak.ID),
|
|
||||||
}
|
|
||||||
app.db.db.Save(&machine)
|
|
||||||
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
|
||||||
|
|
||||||
err = app.db.SetMachineUser(&machine, newUser.Name)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
|
||||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
|
||||||
|
|
||||||
err = app.db.SetMachineUser(&machine, "non-existing-user")
|
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
|
||||||
|
|
||||||
err = app.db.SetMachineUser(&machine, newUser.Name)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
|
||||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
|
||||||
}
|
|
|
@ -1,12 +1,94 @@
|
||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This is borrowed from, and updated to use IPSet
|
||||||
|
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
|
||||||
|
// TODO(kradalby): contribute upstream and make public.
|
||||||
|
var (
|
||||||
|
zeroIP4 = netip.AddrFrom4([4]byte{})
|
||||||
|
zeroIP6 = netip.AddrFrom16([16]byte{})
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseIPSet parses arg as one:
|
||||||
|
//
|
||||||
|
// - an IP address (IPv4 or IPv6)
|
||||||
|
// - the string "*" to match everything (both IPv4 & IPv6)
|
||||||
|
// - a CIDR (e.g. "192.168.0.0/16")
|
||||||
|
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||||
|
//
|
||||||
|
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
|
||||||
|
// address (without a slash) treated as a CIDR of *bits length.
|
||||||
|
// nolint
|
||||||
|
func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
|
||||||
|
var ipSet netipx.IPSetBuilder
|
||||||
|
if arg == "*" {
|
||||||
|
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
|
||||||
|
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
|
||||||
|
|
||||||
|
return ipSet.IPSet()
|
||||||
|
}
|
||||||
|
if strings.Contains(arg, "/") {
|
||||||
|
pfx, err := netip.ParsePrefix(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if pfx != pfx.Masked() {
|
||||||
|
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipSet.AddPrefix(pfx)
|
||||||
|
|
||||||
|
return ipSet.IPSet()
|
||||||
|
}
|
||||||
|
if strings.Count(arg, "-") == 1 {
|
||||||
|
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
||||||
|
|
||||||
|
ip1, err := netip.ParseAddr(ip1s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ip2, err := netip.ParseAddr(ip2s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r := netipx.IPRangeFrom(ip1, ip2)
|
||||||
|
if !r.IsValid() {
|
||||||
|
return nil, fmt.Errorf("invalid IP range %q", arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range r.Prefixes() {
|
||||||
|
ipSet.AddPrefix(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipSet.IPSet()
|
||||||
|
}
|
||||||
|
ip, err := netip.ParseAddr(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid IP address %q", arg)
|
||||||
|
}
|
||||||
|
bits8 := uint8(ip.BitLen())
|
||||||
|
if bits != nil {
|
||||||
|
if *bits < 0 || *bits > int(bits8) {
|
||||||
|
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
|
||||||
|
}
|
||||||
|
bits8 = uint8(*bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
|
||||||
|
|
||||||
|
return ipSet.IPSet()
|
||||||
|
}
|
||||||
|
|
||||||
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
||||||
var network, broadcast netip.Addr
|
var network, broadcast netip.Addr
|
||||||
ipRange := netipx.RangeOfPrefix(na)
|
ipRange := netipx.RangeOfPrefix(na)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package hscontrol
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := parseIPSet(tt.args.arg, tt.args.bits)
|
got, err := ParseIPSet(tt.args.arg, tt.args.bits)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
7
hscontrol/util/const.go
Normal file
7
hscontrol/util/const.go
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
const (
|
||||||
|
RegisterMethodAuthKey = "authkey"
|
||||||
|
RegisterMethodOIDC = "oidc"
|
||||||
|
RegisterMethodCLI = "cli"
|
||||||
|
)
|
69
hscontrol/util/dns.go
Normal file
69
hscontrol/util/dns.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// value related to RFC 1123 and 952.
|
||||||
|
LabelHostnameLength = 63
|
||||||
|
)
|
||||||
|
|
||||||
|
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||||
|
|
||||||
|
var ErrInvalidUserName = errors.New("invalid user name")
|
||||||
|
|
||||||
|
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||||
|
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||||
|
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
name = strings.ReplaceAll(name, "'", "")
|
||||||
|
atIdx := strings.Index(name, "@")
|
||||||
|
if stripEmailDomain && atIdx > 0 {
|
||||||
|
name = name[:atIdx]
|
||||||
|
} else {
|
||||||
|
name = strings.ReplaceAll(name, "@", ".")
|
||||||
|
}
|
||||||
|
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||||
|
|
||||||
|
for _, elt := range strings.Split(name, ".") {
|
||||||
|
if len(elt) > LabelHostnameLength {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"label %v is more than 63 chars: %w",
|
||||||
|
elt,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckForFQDNRules(name string) error {
|
||||||
|
if len(name) > LabelHostnameLength {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
|
||||||
|
name,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if strings.ToLower(name) != name {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
|
||||||
|
name,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if invalidCharsInUserRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
|
||||||
|
name,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
143
hscontrol/util/dns_test.go
Normal file
143
hscontrol/util/dns_test.go
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNormalizeToFQDNRules(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
stripEmailDomain bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normalize simple name",
|
||||||
|
args: args{
|
||||||
|
name: "normalize-simple.name",
|
||||||
|
stripEmailDomain: false,
|
||||||
|
},
|
||||||
|
want: "normalize-simple.name",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normalize an email",
|
||||||
|
args: args{
|
||||||
|
name: "foo.bar@example.com",
|
||||||
|
stripEmailDomain: false,
|
||||||
|
},
|
||||||
|
want: "foo.bar.example.com",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normalize an email domain should be removed",
|
||||||
|
args: args{
|
||||||
|
name: "foo.bar@example.com",
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
want: "foo.bar",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "strip enabled no email passed as argument",
|
||||||
|
args: args{
|
||||||
|
name: "not-email-and-strip-enabled",
|
||||||
|
stripEmailDomain: true,
|
||||||
|
},
|
||||||
|
want: "not-email-and-strip-enabled",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normalize complex email",
|
||||||
|
args: args{
|
||||||
|
name: "foo.bar+complex-email@example.com",
|
||||||
|
stripEmailDomain: false,
|
||||||
|
},
|
||||||
|
want: "foo.bar-complex-email.example.com",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user name with space",
|
||||||
|
args: args{
|
||||||
|
name: "name space",
|
||||||
|
stripEmailDomain: false,
|
||||||
|
},
|
||||||
|
want: "name-space",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user with quote",
|
||||||
|
args: args{
|
||||||
|
name: "Jamie's iPhone 5",
|
||||||
|
stripEmailDomain: false,
|
||||||
|
},
|
||||||
|
want: "jamies-iphone-5",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf(
|
||||||
|
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
||||||
|
err,
|
||||||
|
tt.wantErr,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckForFQDNRules(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid: user",
|
||||||
|
args: args{name: "valid-user"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid: capitalized user",
|
||||||
|
args: args{name: "Invalid-CapItaLIzed-user"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid: email as user",
|
||||||
|
args: args{name: "foo.bar@example.com"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid: chars in user name",
|
||||||
|
args: args{name: "super-user+name"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid: too long name for user",
|
||||||
|
args: args{
|
||||||
|
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -45,7 +45,7 @@ var veryLargeDestination = []string{
|
||||||
"208.0.0.0/4:*",
|
"208.0.0.0/4:*",
|
||||||
}
|
}
|
||||||
|
|
||||||
func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario {
|
func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
// they can access minus one (them self).
|
// they can access minus one (them self).
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
users map[string]int
|
users map[string]int
|
||||||
policy hscontrol.ACLPolicy
|
policy policy.ACLPolicy
|
||||||
want map[string]int
|
want map[string]int
|
||||||
}{
|
}{
|
||||||
// Test that when we have no ACL, each client netmap has
|
// Test that when we have no ACL, each client netmap has
|
||||||
|
@ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
"user1": 2,
|
"user1": 2,
|
||||||
"user2": 2,
|
"user2": 2,
|
||||||
},
|
},
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
|
@ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
"user1": 2,
|
"user1": 2,
|
||||||
"user2": 2,
|
"user2": 2,
|
||||||
},
|
},
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
"user1": 2,
|
"user1": 2,
|
||||||
"user2": 2,
|
"user2": 2,
|
||||||
},
|
},
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
"user1": 2,
|
"user1": 2,
|
||||||
"user2": 2,
|
"user2": 2,
|
||||||
},
|
},
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
"user1": 2,
|
"user1": 2,
|
||||||
"user2": 2,
|
"user2": 2,
|
||||||
},
|
},
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
scenario := aclScenario(t,
|
scenario := aclScenario(t,
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
scenario := aclScenario(t,
|
scenario := aclScenario(t,
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:integration-acl-test": {"user1", "user2"},
|
"group:integration-acl-test": {"user1", "user2"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:integration-acl-test"},
|
Sources: []string{"group:integration-acl-test"},
|
||||||
|
@ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
scenario := aclScenario(t,
|
scenario := aclScenario(t,
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
scenario := aclScenario(t,
|
scenario := aclScenario(t,
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"user1"},
|
Sources: []string{"user1"},
|
||||||
|
@ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
scenario := aclScenario(t,
|
scenario := aclScenario(t,
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Hosts: hscontrol.Hosts{
|
Hosts: policy.Hosts{
|
||||||
"all": netip.MustParsePrefix("100.64.0.0/24"),
|
"all": netip.MustParsePrefix("100.64.0.0/24"),
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
// Everyone can curl test3
|
// Everyone can curl test3
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
|
@ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
policy hscontrol.ACLPolicy
|
policy policy.ACLPolicy
|
||||||
}{
|
}{
|
||||||
"ipv4": {
|
"ipv4": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
Hosts: hscontrol.Hosts{
|
Hosts: policy.Hosts{
|
||||||
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
"test3": netip.MustParsePrefix("100.64.0.3/32"),
|
"test3": netip.MustParsePrefix("100.64.0.3/32"),
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
// Everyone can curl test3
|
// Everyone can curl test3
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
|
@ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ipv6": {
|
"ipv6": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
Hosts: hscontrol.Hosts{
|
Hosts: policy.Hosts{
|
||||||
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"),
|
"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"),
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
// Everyone can curl test3
|
// Everyone can curl test3
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
|
@ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
|
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
policy hscontrol.ACLPolicy
|
policy policy.ACLPolicy
|
||||||
}{
|
}{
|
||||||
"ipv4": {
|
"ipv4": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"100.64.0.1"},
|
Sources: []string{"100.64.0.1"},
|
||||||
|
@ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ipv6": {
|
"ipv6": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"fd7a:115c:a1e0::1"},
|
Sources: []string{"fd7a:115c:a1e0::1"},
|
||||||
|
@ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"hostv4cidr": {
|
"hostv4cidr": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
Hosts: hscontrol.Hosts{
|
Hosts: policy.Hosts{
|
||||||
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"test1"},
|
Sources: []string{"test1"},
|
||||||
|
@ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"hostv6cidr": {
|
"hostv6cidr": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
Hosts: hscontrol.Hosts{
|
Hosts: policy.Hosts{
|
||||||
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"test1"},
|
Sources: []string{"test1"},
|
||||||
|
@ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"group": {
|
"group": {
|
||||||
policy: hscontrol.ACLPolicy{
|
policy: policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:one": {"user1"},
|
"group:one": {"user1"},
|
||||||
"group:two": {"user2"},
|
"group:two": {"user2"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:one"},
|
Sources: []string{"group:one"},
|
||||||
|
|
|
@ -23,7 +23,7 @@ import (
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/integrationutil"
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
|
@ -60,7 +60,7 @@ type HeadscaleInContainer struct {
|
||||||
port int
|
port int
|
||||||
extraPorts []string
|
extraPorts []string
|
||||||
hostPortBindings map[string][]string
|
hostPortBindings map[string][]string
|
||||||
aclPolicy *hscontrol.ACLPolicy
|
aclPolicy *policy.ACLPolicy
|
||||||
env map[string]string
|
env map[string]string
|
||||||
tlsCert []byte
|
tlsCert []byte
|
||||||
tlsKey []byte
|
tlsKey []byte
|
||||||
|
@ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer)
|
||||||
|
|
||||||
// WithACLPolicy adds a hscontrol.ACLPolicy policy to the
|
// WithACLPolicy adds a hscontrol.ACLPolicy policy to the
|
||||||
// HeadscaleInContainer instance.
|
// HeadscaleInContainer instance.
|
||||||
func WithACLPolicy(acl *hscontrol.ACLPolicy) Option {
|
func WithACLPolicy(acl *policy.ACLPolicy) Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
// TODO(kradalby): Move somewhere appropriate
|
// TODO(kradalby): Move somewhere appropriate
|
||||||
hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath
|
hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) {
|
||||||
err = scenario.CreateHeadscaleEnv(spec,
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
[]tsic.Option{tsic.WithSSH()},
|
[]tsic.Option{tsic.WithSSH()},
|
||||||
hsic.WithACLPolicy(
|
hsic.WithACLPolicy(
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:integration-test": {"user1"},
|
"group:integration-test": {"user1"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
Destinations: []string{"*:*"},
|
Destinations: []string{"*:*"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SSHs: []hscontrol.SSH{
|
SSHs: []policy.SSH{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:integration-test"},
|
Sources: []string{"group:integration-test"},
|
||||||
|
@ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
|
||||||
err = scenario.CreateHeadscaleEnv(spec,
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
[]tsic.Option{tsic.WithSSH()},
|
[]tsic.Option{tsic.WithSSH()},
|
||||||
hsic.WithACLPolicy(
|
hsic.WithACLPolicy(
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:integration-test": {"user1", "user2"},
|
"group:integration-test": {"user1", "user2"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
Destinations: []string{"*:*"},
|
Destinations: []string{"*:*"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SSHs: []hscontrol.SSH{
|
SSHs: []policy.SSH{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:integration-test"},
|
Sources: []string{"group:integration-test"},
|
||||||
|
@ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) {
|
||||||
err = scenario.CreateHeadscaleEnv(spec,
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
[]tsic.Option{tsic.WithSSH()},
|
[]tsic.Option{tsic.WithSSH()},
|
||||||
hsic.WithACLPolicy(
|
hsic.WithACLPolicy(
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:integration-test": {"user1"},
|
"group:integration-test": {"user1"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
Destinations: []string{"*:*"},
|
Destinations: []string{"*:*"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SSHs: []hscontrol.SSH{},
|
SSHs: []policy.SSH{},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
hsic.WithTestName("sshnoneconfigured"),
|
hsic.WithTestName("sshnoneconfigured"),
|
||||||
|
@ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) {
|
||||||
err = scenario.CreateHeadscaleEnv(spec,
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
[]tsic.Option{tsic.WithSSH()},
|
[]tsic.Option{tsic.WithSSH()},
|
||||||
hsic.WithACLPolicy(
|
hsic.WithACLPolicy(
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:integration-test": {"user1"},
|
"group:integration-test": {"user1"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
Destinations: []string{"*:80"},
|
Destinations: []string{"*:80"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SSHs: []hscontrol.SSH{
|
SSHs: []policy.SSH{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:integration-test"},
|
Sources: []string{"group:integration-test"},
|
||||||
|
@ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) {
|
||||||
err = scenario.CreateHeadscaleEnv(spec,
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
[]tsic.Option{tsic.WithSSH()},
|
[]tsic.Option{tsic.WithSSH()},
|
||||||
hsic.WithACLPolicy(
|
hsic.WithACLPolicy(
|
||||||
&hscontrol.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
Groups: map[string][]string{
|
Groups: map[string][]string{
|
||||||
"group:ssh1": {"useracl1"},
|
"group:ssh1": {"useracl1"},
|
||||||
"group:ssh2": {"useracl2"},
|
"group:ssh2": {"useracl2"},
|
||||||
},
|
},
|
||||||
ACLs: []hscontrol.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"*"},
|
Sources: []string{"*"},
|
||||||
Destinations: []string{"*:*"},
|
Destinations: []string{"*:*"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SSHs: []hscontrol.SSH{
|
SSHs: []policy.SSH{
|
||||||
{
|
{
|
||||||
Action: "accept",
|
Action: "accept",
|
||||||
Sources: []string{"group:ssh1"},
|
Sources: []string{"group:ssh1"},
|
||||||
|
|
Loading…
Reference in a new issue