mirror of
https://github.com/juanfont/headscale.git
synced 2025-02-08 10:18:01 +09:00
Return better web errors to the user (#2398)
Some checks are pending
Build / build-nix (push) Waiting to run
Build / build-cross (GOARCH=386 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=5) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=6) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=7) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
Tests / test (push) Waiting to run
Some checks are pending
Build / build-nix (push) Waiting to run
Build / build-cross (GOARCH=386 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=5) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=6) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=7) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
Tests / test (push) Waiting to run
* add dedicated http error to propagate to user Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * classify user errors in http handlers Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * move validation of pre auth key out of db This move separates the logic a bit and allow us to write specific errors for the caller, in this case the web layer so we can present the user with the correct error codes without bleeding web stuff into a generic validate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * update changelog Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
1c7f3bc440
commit
45752db0f6
10 changed files with 268 additions and 229 deletions
|
@ -28,6 +28,8 @@
|
|||
[#2396](https://github.com/juanfont/headscale/pull/2396)
|
||||
- Pre auth keys that are used by a node can no longer be deleted
|
||||
[#2396](https://github.com/juanfont/headscale/pull/2396)
|
||||
- Rehaul HTTP errors, return better status code and errors to users
|
||||
[#2398](https://github.com/juanfont/headscale/pull/2398)
|
||||
|
||||
## 0.24.2 (2025-01-30)
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode(
|
|||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
if node.MachineKey != machineKey {
|
||||
return nil, errors.New("node already exists with different machine key")
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
|
||||
}
|
||||
|
||||
expired := node.IsExpired()
|
||||
|
@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode(
|
|||
|
||||
// The client is trying to extend their key, this is not allowed.
|
||||
if requestExpiry.After(time.Now()) {
|
||||
return nil, errors.New("extending key is not allowed")
|
||||
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
|
||||
}
|
||||
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
|
@ -155,13 +155,42 @@ func (h *Headscale) waitForFollowup(
|
|||
}
|
||||
}
|
||||
|
||||
// canUsePreAuthKey checks if a pre auth key can be used.
|
||||
func canUsePreAuthKey(pak *types.PreAuthKey) error {
|
||||
if pak == nil {
|
||||
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
|
||||
}
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
|
||||
}
|
||||
|
||||
// we don't need to check if has been used before
|
||||
if pak.Reusable {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pak.Used {
|
||||
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
|
||||
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pre auth key: %w", err)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = canUsePreAuthKey(pak)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
|
|
130
hscontrol/auth_test.go
Normal file
130
hscontrol/auth_test.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
)
|
||||
|
||||
func TestCanUsePreAuthKey(t *testing.T) {
|
||||
now := time.Now()
|
||||
past := now.Add(-time.Hour)
|
||||
future := now.Add(time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pak *types.PreAuthKey
|
||||
wantErr bool
|
||||
err HTTPError
|
||||
}{
|
||||
{
|
||||
name: "valid reusable key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: true,
|
||||
Used: false,
|
||||
Expiration: &future,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid non-reusable key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: false,
|
||||
Expiration: &future,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: false,
|
||||
Expiration: &past,
|
||||
},
|
||||
wantErr: true,
|
||||
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
|
||||
},
|
||||
{
|
||||
name: "used non-reusable key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: true,
|
||||
Expiration: &future,
|
||||
},
|
||||
wantErr: true,
|
||||
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
|
||||
},
|
||||
{
|
||||
name: "used reusable key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: true,
|
||||
Used: true,
|
||||
Expiration: &future,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no expiration date",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: false,
|
||||
Expiration: nil,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil preauth key",
|
||||
pak: nil,
|
||||
wantErr: true,
|
||||
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
|
||||
},
|
||||
{
|
||||
name: "expired and used key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: true,
|
||||
Expiration: &past,
|
||||
},
|
||||
wantErr: true,
|
||||
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
|
||||
},
|
||||
{
|
||||
name: "no expiration and used key",
|
||||
pak: &types.PreAuthKey{
|
||||
Reusable: false,
|
||||
Used: true,
|
||||
Expiration: nil,
|
||||
},
|
||||
wantErr: true,
|
||||
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := canUsePreAuthKey(tt.pak)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
} else {
|
||||
httpErr, ok := err.(HTTPError)
|
||||
if !ok {
|
||||
t.Errorf("expected HTTPError but got %T", err)
|
||||
} else {
|
||||
if diff := cmp.Diff(tt.err, httpErr); diff != "" {
|
||||
t.Errorf("unexpected error (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error but got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -10,7 +10,6 @@ import (
|
|||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
|
@ -64,6 +63,7 @@ func CreatePreAuthKey(
|
|||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
// TODO(kradalby): unify the key generations spread all over the code.
|
||||
kstr, err := generateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
|
|||
return keys, nil
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
|
||||
pak, err := ValidatePreAuthKey(tx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return GetPreAuthKey(rx, key)
|
||||
})
|
||||
}
|
||||
|
||||
if pak.User.Name != user {
|
||||
return nil, ErrUserMismatch
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
||||
// for checking if the key is usable (expired or used).
|
||||
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
return pak, nil
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
|
@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
|||
})
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
|
@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return ValidatePreAuthKey(rx, k)
|
||||
})
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return nil, ErrPreAuthKeyExpired
|
||||
}
|
||||
|
||||
if pak.Reusable { // we don't need to check if has been used before
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
|
||||
Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(nodes) != 0 || pak.Used {
|
||||
return nil, ErrSingleUseAuthKeyHasBeenUsed
|
||||
}
|
||||
|
||||
return &pak, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateKey() (string, error) {
|
||||
|
|
|
@ -3,14 +3,14 @@ package db
|
|||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/types/ptr"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
|
@ -43,123 +43,6 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
|||
c.Assert((keys)[0].User.ID, check.Equals, user.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test2"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now().Add(-5 * time.Second)
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||
key, err := db.ValidatePreAuthKey("potatoKey")
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test3"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test4"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test5"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test6"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test3"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.IsNil)
|
||||
|
||||
err = db.ExpirePreAuthKey(pak)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.NotNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test6"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak.Used = true
|
||||
db.DB.Save(&pak)
|
||||
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test8"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
|
|
@ -181,11 +181,15 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
|||
request *v1.ExpirePreAuthKeyRequest,
|
||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||
err := api.h.db.Write(func(tx *gorm.DB) error {
|
||||
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
|
||||
preAuthKey, err := db.GetPreAuthKey(tx, request.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if preAuthKey.User.Name != request.GetUser() {
|
||||
return fmt.Errorf("preauth key does not belong to user")
|
||||
}
|
||||
|
||||
return db.ExpirePreAuthKey(tx, preAuthKey)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -33,10 +33,33 @@ const (
|
|||
)
|
||||
|
||||
// httpError logs an error and sends an HTTP error response with the given
|
||||
func httpError(w http.ResponseWriter, err error, userError string, code int) {
|
||||
log.Error().Err(err).Msg(userError)
|
||||
http.Error(w, userError, code)
|
||||
func httpError(w http.ResponseWriter, err error) {
|
||||
var herr HTTPError
|
||||
if errors.As(err, &herr) {
|
||||
http.Error(w, herr.Msg, herr.Code)
|
||||
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
|
||||
} else {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
log.Error().Err(err).Int("code", http.StatusInternalServerError).Msg("http internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPError represents an error that is surfaced to the user via web.
|
||||
type HTTPError struct {
|
||||
Code int // HTTP response code to send to client; 0 means 500
|
||||
Msg string // Response body to send to client
|
||||
Err error // Detailed error to log on the server
|
||||
}
|
||||
|
||||
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
||||
func (e HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
// Error returns an HTTPError containing the given information.
|
||||
func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||
return HTTPError{Code: code, Msg: msg, Err: err}
|
||||
}
|
||||
|
||||
var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
|
||||
|
||||
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||
"machines registered with CLI does not support expire",
|
||||
|
@ -47,12 +70,12 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
|
|||
clientCapabilityStr := req.URL.Query().Get("v")
|
||||
|
||||
if clientCapabilityStr == "" {
|
||||
return 0, ErrNoCapabilityVersion
|
||||
return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil)
|
||||
}
|
||||
|
||||
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse capability version: %w", err)
|
||||
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
|
||||
}
|
||||
|
||||
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
||||
|
@ -85,13 +108,13 @@ func (h *Headscale) VerifyHandler(
|
|||
req *http.Request,
|
||||
) {
|
||||
if req.Method != http.MethodPost {
|
||||
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
|
||||
httpError(writer, errMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
allow, err := h.derpRequestIsAllowed(req)
|
||||
if err != nil {
|
||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -112,7 +135,7 @@ func (h *Headscale) KeyHandler(
|
|||
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
|
||||
capVer, err := parseCabailityVersion(req)
|
||||
if err != nil {
|
||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -199,7 +222,7 @@ func (a *AuthProviderWeb) RegisterHandler(
|
|||
// the template and log an error.
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
if err != nil {
|
||||
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package hscontrol
|
|||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/control/controlhttp/controlhttpserver"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -81,7 +83,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||
noiseServer.earlyNoise,
|
||||
)
|
||||
if err != nil {
|
||||
httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError)
|
||||
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -198,7 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||
|
||||
var mapRequest tailcfg.MapRequest
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -211,7 +213,11 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||
|
||||
node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
httpError(writer, NewHTTPError(http.StatusNotFound, "node not found", nil))
|
||||
return
|
||||
}
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -230,7 +236,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
req *http.Request,
|
||||
) {
|
||||
if req.Method != http.MethodPost {
|
||||
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
|
||||
httpError(writer, errMethodNotAllowed)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||
// the template and log an error.
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
if err != nil {
|
||||
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set the state and nonce cookies to protect against CSRF attacks
|
||||
state, err := setCSRFCookie(writer, req, "state")
|
||||
if err != nil {
|
||||
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the state and nonce cookies to protect against CSRF attacks
|
||||
nonce, err := setCSRFCookie(writer, req, "nonce")
|
||||
if err != nil {
|
||||
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -219,34 +219,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
) {
|
||||
code, state, err := extractCodeAndStateParamFromRequest(req)
|
||||
if err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusBadRequest)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
cookieState, err := req.Cookie("state")
|
||||
if err != nil {
|
||||
httpError(writer, err, "state not found", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||
return
|
||||
}
|
||||
|
||||
if state != cookieState.Value {
|
||||
httpError(writer, err, "state did not match", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil))
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := a.extractIDToken(req.Context(), code, state)
|
||||
if err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusBadRequest)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
nonce, err := req.Cookie("nonce")
|
||||
if err != nil {
|
||||
httpError(writer, err, "nonce not found", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||
return
|
||||
}
|
||||
if idToken.Nonce != nonce.Value {
|
||||
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -254,29 +254,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
var claims types.OIDCClaims
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
err = fmt.Errorf("decoding ID token claims: %w", err)
|
||||
httpError(writer, err, err.Error(), http.StatusInternalServerError)
|
||||
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusUnauthorized)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusUnauthorized)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusUnauthorized)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
if err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// TODO(kradalby): replace with go-elem
|
||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
||||
if err != nil {
|
||||
httpError(writer, err, err.Error(), http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest(
|
|||
state := req.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
return "", "", errEmptyOIDCCallbackParams
|
||||
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams)
|
||||
}
|
||||
|
||||
return code, state, nil
|
||||
|
@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken(
|
|||
if a.cfg.PKCE.Enabled {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, errNoOIDCRegistrationInfo
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
|
@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken(
|
|||
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not exchange code for token: %w", err)
|
||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, errNoOIDCIDToken
|
||||
return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken)
|
||||
}
|
||||
|
||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
|
@ -381,7 +380,7 @@ func validateOIDCAllowedDomains(
|
|||
if len(allowedDomains) > 0 {
|
||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
||||
return errOIDCAllowedDomains
|
||||
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -403,7 +402,7 @@ func validateOIDCAllowedGroups(
|
|||
}
|
||||
}
|
||||
|
||||
return errOIDCAllowedGroups
|
||||
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -417,7 +416,7 @@ func validateOIDCAllowedUsers(
|
|||
) error {
|
||||
if len(allowedUsers) > 0 &&
|
||||
!slices.Contains(allowedUsers, claims.Email) {
|
||||
return errOIDCAllowedUsers
|
||||
return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) handleRegistrationID(
|
||||
func (a *AuthProviderOIDC) handleRegistration(
|
||||
user *types.User,
|
||||
registrationID types.RegistrationID,
|
||||
expiry time.Time,
|
||||
|
|
|
@ -39,19 +39,19 @@ func (h *Headscale) ApplePlatformConfig(
|
|||
vars := mux.Vars(req)
|
||||
platform, ok := vars["platform"]
|
||||
if !ok {
|
||||
httpError(writer, nil, "No platform specified", http.StatusBadRequest)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
contentID, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -65,21 +65,21 @@ func (h *Headscale) ApplePlatformConfig(
|
|||
switch platform {
|
||||
case "macos-standalone":
|
||||
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "macos-app-store":
|
||||
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "ios":
|
||||
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError)
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,7 @@ func (h *Headscale) ApplePlatformConfig(
|
|||
|
||||
var content bytes.Buffer
|
||||
if err := commonTemplate.Execute(&content, config); err != nil {
|
||||
httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError)
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue