Initial work eliminating one/two letter variables

This commit is contained in:
Kristoffer Dalby 2021-11-14 20:32:03 +01:00
parent 53ed749f45
commit 471c0b4993
No known key found for this signature in database
GPG key ID: 09F62DC067465735
19 changed files with 568 additions and 532 deletions

View file

@ -28,6 +28,9 @@ linters:
# In progress # In progress
- gocritic - gocritic
# TODO: approve: ok, db, id
- varnamelen
# We should strive to enable these: # We should strive to enable these:
- testpackage - testpackage
- stylecheck - stylecheck
@ -39,7 +42,6 @@ linters:
- gosec - gosec
- forbidigo - forbidigo
- dupl - dupl
- varnamelen
- makezero - makezero
- paralleltest - paralleltest

74
acls.go
View file

@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close() defer policyFile.Close()
var policy ACLPolicy var policy ACLPolicy
b, err := io.ReadAll(policyFile) policyBytes, err := io.ReadAll(policyFile)
if err != nil { if err != nil {
return err return err
} }
ast, err := hujson.Parse(b) ast, err := hujson.Parse(policyBytes)
if err != nil { if err != nil {
return err return err
} }
ast.Standardize() ast.Standardize()
b = ast.Pack() policyBytes = ast.Pack()
err = json.Unmarshal(b, &policy) err = json.Unmarshal(policyBytes, &policy)
if err != nil { if err != nil {
return err return err
} }
@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs { for index, acl := range h.aclPolicy.ACLs {
if a.Action != "accept" { if acl.Action != "accept" {
return nil, errorInvalidAction return nil, errorInvalidAction
} }
r := tailcfg.FilterRule{} filterRule := tailcfg.FilterRule{}
srcIPs := []string{} srcIPs := []string{}
for j, u := range a.Users { for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(u) srcs, err := h.generateACLPolicySrcIP(user)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, User %d", i, j) Msgf("Error parsing ACL %d, User %d", index, innerIndex)
return nil, err return nil, err
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
r.SrcIPs = srcIPs filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports { for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(d) dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j) Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
return nil, err return nil, err
} }
@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts(
return dests, nil return dests, nil
} }
func (h *Headscale) expandAlias(s string) ([]string, error) { func (h *Headscale) expandAlias(alias string) ([]string, error) {
if s == "*" { if alias == "*" {
return []string{"*"}, nil return []string{"*"}, nil
} }
if strings.HasPrefix(s, "group:") { if strings.HasPrefix(alias, "group:") {
if _, ok := h.aclPolicy.Groups[s]; !ok { if _, ok := h.aclPolicy.Groups[alias]; !ok {
return nil, errorInvalidGroup return nil, errorInvalidGroup
} }
ips := []string{} ips := []string{}
for _, n := range h.aclPolicy.Groups[s] { for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n) nodes, err := h.ListMachinesInNamespace(n)
if err != nil { if err != nil {
return nil, errorInvalidNamespace return nil, errorInvalidNamespace
@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil return ips, nil
} }
if strings.HasPrefix(s, "tag:") { if strings.HasPrefix(alias, "tag:") {
if _, ok := h.aclPolicy.TagOwners[s]; !ok { if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
return nil, errorInvalidTag return nil, errorInvalidTag
} }
@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err return nil, err
} }
ips := []string{} ips := []string{}
for _, m := range machines { for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this // FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags { for _, t := range hostinfo.RequestTags {
if s[4:] == t { if alias[4:] == t {
ips = append(ips, m.IPAddress) ips = append(ips, machine.IPAddress)
break break
} }
@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil return ips, nil
} }
n, err := h.GetNamespace(s) n, err := h.GetNamespace(alias)
if err == nil { if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name) nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil { if err != nil {
@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil return ips, nil
} }
if h, ok := h.aclPolicy.Hosts[s]; ok { if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil return []string{h.String()}, nil
} }
ip, err := netaddr.ParseIP(s) ip, err := netaddr.ParseIP(alias)
if err == nil { if err == nil {
return []string{ip.String()}, nil return []string{ip.String()}, nil
} }
cidr, err := netaddr.ParseIPPrefix(s) cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil { if err == nil {
return []string{cidr.String()}, nil return []string{cidr.String()}, nil
} }
@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, errorInvalidUserSection return nil, errorInvalidUserSection
} }
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if s == "*" { if portsStr == "*" {
return &[]tailcfg.PortRange{ return &[]tailcfg.PortRange{
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END}, {First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
}, nil }, nil
} }
ports := []tailcfg.PortRange{} ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") { for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(p, "-") rang := strings.Split(portStr, "-")
switch len(rang) { switch len(rang) {
case 1: case 1:
pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16) port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ports = append(ports, tailcfg.PortRange{ ports = append(ports, tailcfg.PortRange{
First: uint16(pi), First: uint16(port),
Last: uint16(pi), Last: uint16(port),
}) })
case EXPECTED_TOKEN_ITEMS: case EXPECTED_TOKEN_ITEMS:

View file

@ -41,37 +41,37 @@ type ACLTest struct {
} }
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects. // UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (h *Hosts) UnmarshalJSON(data []byte) error { func (hosts *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{} newHosts := Hosts{}
hs := make(map[string]string) hostIpPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data) ast, err := hujson.Parse(data)
if err != nil { if err != nil {
return err return err
} }
ast.Standardize() ast.Standardize()
data = ast.Pack() data = ast.Pack()
err = json.Unmarshal(data, &hs) err = json.Unmarshal(data, &hostIpPrefixMap)
if err != nil { if err != nil {
return err return err
} }
for k, v := range hs { for host, prefixStr := range hostIpPrefixMap {
if !strings.Contains(v, "/") { if !strings.Contains(prefixStr, "/") {
v += "/32" prefixStr += "/32"
} }
prefix, err := netaddr.ParseIPPrefix(v) prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil { if err != nil {
return err return err
} }
hosts[k] = prefix newHosts[host] = prefix
} }
*h = hosts *hosts = newHosts
return nil return nil
} }
// IsZero is perhaps a bit naive here. // IsZero is perhaps a bit naive here.
func (p ACLPolicy) IsZero() bool { func (policy ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true return true
} }

229
api.go
View file

@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
func (h *Headscale) KeyHandler(c *gin.Context) { func (h *Headscale) KeyHandler(ctx *gin.Context) {
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString())) ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
[]byte(h.publicKey.HexString()),
)
} }
// RegisterWebAPI shows a simple message in the browser to point to the CLI // RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register. // Listens in /register.
func (h *Headscale) RegisterWebAPI(c *gin.Context) { func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
mKeyStr := c.Query("key") machineKeyStr := ctx.Query("key")
if mKeyStr == "" { if machineKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
</body> </body>
</html> </html>
`, mKeyStr))) `, machineKeyStr)))
} }
// RegistrationHandler handles the actual registration process of a machine // RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id. // Endpoint /machine/:id.
func (h *Headscale) RegistrationHandler(c *gin.Context) { func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := c.Param("id") machineKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!") ctx.String(http.StatusInternalServerError, "Sad!")
return return
} }
req := tailcfg.RegisterRequest{} req := tailcfg.RegisterRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!") ctx.String(http.StatusInternalServerError, "Very sad!")
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
m, err := h.GetMachineByMachineKey(mKey.HexString()) machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{ newMachine := Machine{
Expiry: &time.Time{}, Expiry: &time.Time{},
MachineKey: mKey.HexString(), MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
} }
if err := h.db.Create(&newMachine).Error; err != nil { if err := h.db.Create(&newMachine).Error; err != nil {
@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Could not create row") Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc() Inc()
return return
} }
m = &newMachine machine = &newMachine
} }
if !m.Registered && req.Auth.AuthKey != "" { if !machine.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, *m) h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
return return
} }
@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info(). log.Info().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client requested logout") Msg("Client requested logout")
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m) h.db.Save(&machine)
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
if m.Registered && m.Expiry.UTC().After(now) { if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map // The machine registration is valid, respond with redirect to /map
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map") Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
resp.Login = *m.Namespace.toLogin() resp.Login = *machine.Namespace.toLogin()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc() Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The client has registered before, but has expired // The client has registered before, but has expired
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} }
// When a client connects, it may request a specific expiry time in its // When a client connects, it may request a specific expiry time in its
@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be // into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow // retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry. // completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry machine.RequestedExpiry = &req.Expiry
h.db.Save(&m) h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
Inc() Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
machine.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
m.NodeKey = wgkey.Key(req.NodeKey).HexString() machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m) h.db.Save(&machine)
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *m.Namespace.toUser() resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return return
} }
@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The machine registration is new, redirect the client to the registration URL // The machine registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
mKey.HexString(), machineKey.HexString(),
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} }
// save the requested expiry time for retrieval later in the authentication flow // save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry machine.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m) h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
func (h *Headscale) getMapResponse( func (h *Headscale) getMapResponse(
mKey wgkey.Key, machineKey wgkey.Key,
req tailcfg.MapRequest, req tailcfg.MapRequest,
m *Machine, machine *Machine,
) ([]byte, error) { ) ([]byte, error) {
log.Trace(). log.Trace().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse(
return nil, err return nil, err
} }
peers, err := h.getPeers(m) peers, err := h.getPeers(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse(
return nil, err return nil, err
} }
profiles := getMapResponseUserProfiles(*m, peers) profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse(
dnsConfig := getMapResponseDNSConfig( dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig, h.cfg.DNSConfig,
h.cfg.BaseDomain, h.cfg.BaseDomain,
*m, *machine,
peers, peers,
) )
@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse(
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil) srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
respBody, err = encode(resp, &mKey, h.privateKey) respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse(
} }
func (h *Headscale) getMapKeepAliveResponse( func (h *Headscale) getMapKeepAliveResponse(
mKey wgkey.Key, machineKey wgkey.Key,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
) ([]byte, error) { ) ([]byte, error) {
resp := tailcfg.MapResponse{ mapResponse := tailcfg.MapResponse{
KeepAlive: true, KeepAlive: true,
} }
var respBody []byte var respBody []byte
var err error var err error
if req.Compress == "zstd" { if mapRequest.Compress == "zstd" {
src, _ := json.Marshal(resp) src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil) srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
respBody, err = encode(resp, &mKey, h.privateKey) respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse(
} }
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
c *gin.Context, ctx *gin.Context,
db *gorm.DB, db *gorm.DB,
idKey wgkey.Key, idKey wgkey.Key,
req tailcfg.RegisterRequest, reqisterRequest tailcfg.RegisterRequest,
m Machine, machine Machine,
) { ) {
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", req.Hostinfo.Hostname). Str("machine", reqisterRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", req.Hostinfo.Hostname) Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey) pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
return return
} }
c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
return return
@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey(
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address") Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Failed to find an available IP") Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
return return
} }
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msgf("Assigning %s to %s", ip, m.Name) Msgf("Assigning %s to %s", ip, machine.Name)
m.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = pak.NamespaceID machine.NamespaceID = pak.NamespaceID
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
m.Registered = true HexString()
m.RegisterMethod = "authKey" // we update it just in case
db.Save(&m) machine.Registered = true
machine.RegisterMethod = "authKey"
db.Save(&machine)
pak.Used = true pak.Used = true
db.Save(&pak) db.Save(&pak)
@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
Inc() Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

140
app.go
View file

@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errors.New("unsupported DB") return nil, errors.New("unsupported DB")
} }
h := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
} }
err = h.initDB() err = app.initDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cfg.OIDC.Issuer != "" { if cfg.OIDC.Issuer != "" {
err = h.initOIDC() err = app.initOIDC()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := generateMagicDNSRootDomains( magicDNSDomains := generateMagicDNSRootDomains(
h.cfg.IPPrefix, app.cfg.IPPrefix,
) )
// we might have routes already from Split DNS // we might have routes already from Split DNS
if h.cfg.DNSConfig.Routes == nil { if app.cfg.DNSConfig.Routes == nil {
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
} }
for _, d := range magicDNSDomains { for _, d := range magicDNSDomains {
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
} }
} }
return &h, nil return &app, nil
} }
// Redirect to our TLS url. // Redirect to our TLS url.
@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return return
} }
for _, ns := range namespaces { for _, namespace := range namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name) machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("namespace", ns.Name). Str("namespace", namespace.Name).
Msg("Error listing machines in namespace") Msg("Error listing machines in namespace")
return return
} }
for _, m := range machines { for _, machine := range machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && if machine.AuthKey != nil && machine.LastSeen != nil &&
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { machine.AuthKey.Ephemeral &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info(). log.Info().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error err = h.db.Unscoped().Delete(machine).Error
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database") Msg("🤮 Cannot delete ephemeral machine from the database")
} }
} }
} }
h.setLastStateChangeToNow(ns.Name) h.setLastStateChangeToNow(namespace.Name)
} }
} }
@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// with the "legacy" database-based client // with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to // It is also neede for grpc-gateway to be able to connect to
// the server // the server
p, _ := peer.FromContext(ctx) client, _ := peer.FromContext(ctx)
log.Trace(). log.Trace().
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", client.Addr.String()).
Msg("Client is trying to authenticate") Msg("Client is trying to authenticate")
md, ok := metadata.FromIncomingContext(ctx) meta, ok := metadata.FromIncomingContext(ctx)
if !ok { if !ok {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed") Msg("Retrieving metadata is failed")
return ctx, status.Errorf( return ctx, status.Errorf(
@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
) )
} }
authHeader, ok := md["authorization"] authHeader, ok := meta["authorization"]
if !ok { if !ok {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied") Msg("Authorization token is not supplied")
return ctx, status.Errorf( return ctx, status.Errorf(
@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
if !strings.HasPrefix(token, AUTH_PREFIX) { if !strings.HasPrefix(token, AUTH_PREFIX) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error( return ctx, status.Error(
@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// return handler(ctx, req) // return handler(ctx, req)
} }
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) { func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("client_address", c.ClientIP()). Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked") Msg("HTTP authentication invoked")
authHeader := c.GetHeader("authorization") authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) { if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", c.ClientIP()). Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
c.AbortWithStatus(http.StatusUnauthorized) ctx.AbortWithStatus(http.StatusUnauthorized)
return return
} }
c.AbortWithStatus(http.StatusUnauthorized) ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend // TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow // Currently all traffic is unauthorized, this is intentional to allow
@ -438,9 +440,9 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port. // Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully. // The two following listeners will be served on the same port below gracefully.
m := cmux.New(networkListener) networkMutex := cmux.New(networkListener)
// Match gRPC requests here // Match gRPC requests here
grpcListener := m.MatchWithWriters( grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"), cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings( cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type", "content-type",
@ -448,7 +450,7 @@ func (h *Headscale) Serve() error {
), ),
) )
// Otherwise match regular http requests. // Otherwise match regular http requests.
httpListener := m.Match(cmux.Any()) httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux() grpcGatewayMux := runtime.NewServeMux()
@ -471,33 +473,33 @@ func (h *Headscale) Serve() error {
return err return err
} }
r := gin.Default() router := gin.Default()
p := ginprometheus.NewPrometheus("gin") prometheus := ginprometheus.NewPrometheus("gin")
p.Use(r) prometheus.Use(router)
r.GET( router.GET(
"/health", "/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
) )
r.GET("/key", h.KeyHandler) router.GET("/key", h.KeyHandler)
r.GET("/register", h.RegisterWebAPI) router.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", h.OIDCCallback)
r.GET("/apple", h.AppleMobileConfig) router.GET("/apple", h.AppleMobileConfig)
r.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/apple/:platform", h.ApplePlatformConfig)
r.GET("/swagger", SwaggerUI) router.GET("/swagger", SwaggerUI)
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := r.Group("/api") api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware) api.Use(h.httpAuthenticationMiddleware)
{ {
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP)) api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
} }
r.NoRoute(stdoutHandler) router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP) h.DERPMap = GetDERPMap(h.cfg.DERP)
@ -514,7 +516,7 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: r, Handler: router,
ReadTimeout: HTTP_READ_TIMEOUT, ReadTimeout: HTTP_READ_TIMEOUT,
// Go does not handle timeouts in HTTP very well, and there is // Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to // no good way to handle streaming timeouts, therefore we need to
@ -561,29 +563,29 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer) reflection.Register(grpcServer)
reflection.Register(grpcSocket) reflection.Register(grpcSocket)
g := new(errgroup.Group) errorGroup := new(errgroup.Group)
g.Go(func() error { return grpcSocket.Serve(socketListener) }) errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP // TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
g.Go(func() error { return grpcServer.Serve(grpcListener) }) errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil { if tlsConfig != nil {
g.Go(func() error { errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig) tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl) return httpServer.Serve(tlsl)
}) })
} else { } else {
g.Go(func() error { return httpServer.Serve(httpListener) }) errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
} }
g.Go(func() error { return m.Serve() }) errorGroup.Go(func() error { return networkMutex.Serve() })
log.Info(). log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr) Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
return g.Wait() return errorGroup.Wait()
} }
func (h *Headscale) getTLSSettings() (*tls.Config, error) { func (h *Headscale) getTLSSettings() (*tls.Config, error) {
@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Msg("Listening with TLS but ServerURL does not start with https://") Msg("Listening with TLS but ServerURL does not start with https://")
} }
m := autocert.Manager{ certManager := autocert.Manager{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname), HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale // The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443. // must be reachable on port 443.
return m.TLSConfig(), nil return certManager.TLSConfig(), nil
case "HTTP-01": case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
go func() { go func() {
log.Fatal(). log.Fatal().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))). Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server") Msg("failed to set up a HTTP server")
}() }()
return m.TLSConfig(), nil return certManager.TLSConfig(), nil
default: default:
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType") return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
} }
} }
func stdoutHandler(c *gin.Context) { func stdoutHandler(ctx *gin.Context) {
b, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
log.Trace(). log.Trace().
Interface("header", c.Request.Header). Interface("header", ctx.Request.Header).
Interface("proto", c.Request.Proto). Interface("proto", ctx.Request.Proto).
Interface("url", c.Request.URL). Interface("url", ctx.Request.URL).
Bytes("body", b). Bytes("body", body).
Msg("Request did not match") Msg("Request did not match")
} }

View file

@ -12,8 +12,8 @@ import (
// AppleMobileConfig shows a simple message in the browser to point to the CLI // AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register. // Listens in /register.
func (h *Headscale) AppleMobileConfig(c *gin.Context) { func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
t := template.Must(template.New("apple").Parse(` appleTemplate := template.Must(template.New("apple").Parse(`
<html> <html>
<body> <body>
<h1>Apple configuration profiles</h1> <h1>Apple configuration profiles</h1>
@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
} }
var payload bytes.Buffer var payload bytes.Buffer
if err := t.Execute(&payload, config); err != nil { if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error(). log.Error().
Str("handler", "AppleMobileConfig"). Str("handler", "AppleMobileConfig").
Err(err). Err(err).
Msg("Could not render Apple index template") Msg("Could not render Apple index template")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple index template"), []byte("Could not render Apple index template"),
@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
} }
func (h *Headscale) ApplePlatformConfig(c *gin.Context) { func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
platform := c.Param("platform") platform := ctx.Param("platform")
id, err := uuid.NewV4() id, err := uuid.NewV4()
if err != nil { if err != nil {
@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Failed to create UUID"), []byte("Failed to create UUID"),
@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Failed not create UUID") Msg("Failed not create UUID")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Failed to create UUID"), []byte("Failed to create UUID"),
@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple macOS template") Msg("Could not render Apple macOS template")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"), []byte("Could not render Apple macOS template"),
@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple iOS template") Msg("Could not render Apple iOS template")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"), []byte("Could not render Apple iOS template"),
@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return return
} }
default: default:
c.Data( ctx.Data(
http.StatusOK, http.StatusOK,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"), []byte("Invalid platform, only ios and macos is supported"),
@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
Err(err). Err(err).
Msg("Could not render Apple platform template") Msg("Could not render Apple platform template")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple platform template"), []byte("Could not render Apple platform template"),
@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return return
} }
c.Data( ctx.Data(
http.StatusOK, http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8", "application/x-apple-aspen-config; charset=utf-8",
content.Bytes(), content.Bytes(),

View file

@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{
return return
} }
d := pterm.TableData{{"ID", "Name", "Created"}} tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() { for _, namespace := range response.GetNamespaces() {
d = append( tableData = append(
d, tableData,
[]string{ []string{
namespace.GetId(), namespace.GetId(),
namespace.GetName(), namespace.GetName(),
@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{
}, },
) )
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,

View file

@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{
return return
} }
d, err := nodesToPtables(namespace, response.Machines) tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
id, err := cmd.Flags().GetInt("identifier") identifier, err := cmd.Flags().GetInt("identifier")
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close() defer conn.Close()
getRequest := &v1.GetMachineRequest{ getRequest := &v1.GetMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
getResponse, err := client.GetMachine(ctx, getRequest) getResponse, err := client.GetMachine(ctx, getRequest)
@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{
} }
deleteRequest := &v1.DeleteMachineRequest{ deleteRequest := &v1.DeleteMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
confirm := false confirm := false
@ -280,7 +280,7 @@ func sharingWorker(
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
id, err := cmd.Flags().GetInt("identifier") identifier, err := cmd.Flags().GetInt("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
@ -288,7 +288,7 @@ func sharingWorker(
} }
machineRequest := &v1.GetMachineRequest{ machineRequest := &v1.GetMachineRequest{
MachineId: uint64(id), MachineId: uint64(identifier),
} }
machineResponse, err := client.GetMachine(ctx, machineRequest) machineResponse, err := client.GetMachine(ctx, machineRequest)
@ -402,7 +402,7 @@ func nodesToPtables(
currentNamespace string, currentNamespace string,
machines []*v1.Machine, machines []*v1.Machine,
) (pterm.TableData, error) { ) (pterm.TableData, error) {
d := pterm.TableData{ tableData := pterm.TableData{
{ {
"ID", "ID",
"Name", "Name",
@ -448,8 +448,8 @@ func nodesToPtables(
// Shared into this namespace // Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name) namespace = pterm.LightYellow(machine.Namespace.Name)
} }
d = append( tableData = append(
d, tableData,
[]string{ []string{
strconv.FormatUint(machine.Id, headscale.BASE_10), strconv.FormatUint(machine.Id, headscale.BASE_10),
machine.Name, machine.Name,
@ -463,5 +463,5 @@ func nodesToPtables(
) )
} }
return d, nil return tableData, nil
} }

View file

@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close() defer conn.Close()
request := &v1.ListPreAuthKeysRequest{ request := &v1.ListPreAuthKeysRequest{
Namespace: n, Namespace: namespace,
} }
response, err := client.ListPreAuthKeys(ctx, request) response, err := client.ListPreAuthKeys(ctx, request)
@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{
return return
} }
d := pterm.TableData{ tableData := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}, {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
} }
for _, k := range response.PreAuthKeys { for _, key := range response.PreAuthKeys {
expiration := "-" expiration := "-"
if k.GetExpiration() != nil { if key.GetExpiration() != nil {
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05") expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
} }
var reusable string var reusable string
if k.GetEphemeral() { if key.GetEphemeral() {
reusable = "N/A" reusable = "N/A"
} else { } else {
reusable = fmt.Sprintf("%v", k.GetReusable()) reusable = fmt.Sprintf("%v", key.GetReusable())
} }
d = append(d, []string{ tableData = append(tableData, []string{
k.GetId(), key.GetId(),
k.GetKey(), key.GetKey(),
reusable, reusable,
strconv.FormatBool(k.GetEphemeral()), strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(k.GetUsed()), strconv.FormatBool(key.GetUsed()),
expiration, expiration,
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
}) })
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,

View file

@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{
return return
} }
d := routesToPtables(response.Routes) tableData := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -162,14 +162,14 @@ omit the route you do not want to enable.
return return
} }
d := routesToPtables(response.Routes) tableData := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -184,15 +184,15 @@ omit the route you do not want to enable.
// routesToPtables converts the list of routes to a nice table. // routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData { func routesToPtables(routes *v1.Routes) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}} tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() { for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route) enabled := isStringInSlice(routes.EnabledRoutes, route)
d = append(d, []string{route, strconv.FormatBool(enabled)}) tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
} }
return d return tableData
} }
func isStringInSlice(strs []string, s string) bool { func isStringInSlice(strs []string, s string) bool {

View file

@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap() cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg) app, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" { if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path")) aclPath := absPath(viper.GetString("acl_policy_path"))
err = h.LoadACLPolicy(aclPath) err = app.LoadACLPolicy(aclPath)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("path", aclPath). Str("path", aclPath).
@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
} }
} }
return h, nil return app, nil
} }
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {

6
dns.go
View file

@ -79,7 +79,7 @@ func generateMagicDNSRootDomains(
func getMapResponseDNSConfig( func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig, dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string, baseDomain string,
m Machine, machine Machine,
peers Machines, peers Machines,
) *tailcfg.DNSConfig { ) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig var dnsConfig *tailcfg.DNSConfig
@ -88,11 +88,11 @@ func getMapResponseDNSConfig(
dnsConfig = dnsConfigOrig.Clone() dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append( dnsConfig.Domains = append(
dnsConfig.Domains, dnsConfig.Domains,
fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain), fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
) )
namespaceSet := set.New(set.ThreadSafe) namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace) namespaceSet.Add(machine.Namespace)
for _, p := range peers { for _, p := range peers {
namespaceSet.Add(p.Namespace) namespaceSet.Add(p.Namespace)
} }

View file

@ -56,21 +56,21 @@ type (
) )
// For the time being this method is rather naive. // For the time being this method is rather naive.
func (m Machine) isAlreadyRegistered() bool { func (machine Machine) isAlreadyRegistered() bool {
return m.Registered return machine.Registered
} }
// isExpired returns whether the machine registration has expired. // isExpired returns whether the machine registration has expired.
func (m Machine) isExpired() bool { func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry) return time.Now().UTC().After(*machine.Expiry)
} }
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, // If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause // or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the // a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time. // expiry time.
func (h *Headscale) updateMachineExpiry(m *Machine) { func (h *Headscale) updateMachineExpiry(machine *Machine) {
if m.isExpired() { if machine.isExpired() {
now := time.Now().UTC() now := time.Now().UTC()
maxExpiry := now.Add( maxExpiry := now.Add(
h.cfg.MaxMachineRegistrationDuration, h.cfg.MaxMachineRegistrationDuration,
@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
) // calculate the default expiry ) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) { if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug(). log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
m.Expiry = &maxExpiry machine.Expiry = &maxExpiry
} else if m.RequestedExpiry.IsZero() { } else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
m.Expiry = &defaultExpiry machine.Expiry = &defaultExpiry
} else { } else {
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
m.Expiry = m.RequestedExpiry machine.Expiry = machine.RequestedExpiry
} }
h.db.Save(&m) h.db.Save(&machine)
} }
} }
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding direct peers") Msg("Finding direct peers")
machines := Machines{} machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err return Machines{}, err
@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String()) Msgf("Found direct machines: %s", machines.String())
return machines, nil return machines, nil
} }
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for. // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
func (h *Headscale) getShared(m *Machine) (Machines, error) { func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding shared peers") Msg("Finding shared peers")
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil { machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err return Machines{}, err
} }
@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String()) Msgf("Found shared peers: %s", peers.String())
return peers, nil return peers, nil
} }
// getSharedTo fetches the machines of the namespaces this machine is shared in. // getSharedTo fetches the machines of the namespaces this machine is shared in.
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with") Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil { machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err return Machines{}, err
} }
@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String()) Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil return peers, nil
} }
func (h *Headscale) getPeers(m *Machine) (Machines, error) { func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m) direct, err := h.getDirectPeers(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err return Machines{}, err
} }
shared, err := h.getShared(m) shared, err := h.getShared(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err return Machines{}, err
} }
sharedTo, err := h.getSharedTo(m) sharedTo, err := h.getSharedTo(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String()) Msgf("Found total peers: %s", peers.String())
return peers, nil return peers, nil
@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
} }
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{} m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
// UpdateMachine takes a Machine struct pointer (typically already loaded from database // UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database. // and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error { func (h *Headscale) UpdateMachine(machine *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil { if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error return result.Error
} }
@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
} }
// DeleteMachine softs deletes a Machine from the database. // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared { if err != nil && err != errorMachineNotShared {
return err return err
} }
m.Registered = false machine.Registered = false
namespaceID := m.NamespaceID namespaceID := machine.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil { if err := h.db.Delete(&machine).Error; err != nil {
return err return err
} }
@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
} }
// HardDeleteMachine hard deletes a Machine from the database. // HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared { if err != nil && err != errorMachineNotShared {
return err return err
} }
namespaceID := m.NamespaceID namespaceID := machine.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil { if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err return err
} }
@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
} }
// GetHostInfo returns a Hostinfo struct for the machine. // GetHostInfo returns a Hostinfo struct for the machine.
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return &hostinfo, nil return &hostinfo, nil
} }
func (h *Headscale) isOutdated(m *Machine) bool { func (h *Headscale) isOutdated(machine *Machine) bool {
if err := h.UpdateMachine(m); err != nil { if err := h.UpdateMachine(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result // It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated. // will have to be that the machine has to be considered outdated.
return true return true
} }
sharedMachines, _ := h.getShared(m) sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe) namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace.Name) namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have // Check if any of our shared namespaces has updates that we have
// not propagated. // not propagated.
@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool {
lastChange := h.getLastStateChange(namespaces...) lastChange := h.getLastStateChange(namespaces...)
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange). Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name) Msgf("Checking if %s is missing updates", machine.Name)
return m.LastSuccessfulUpdate.Before(lastChange) return machine.LastSuccessfulUpdate.Before(lastChange)
} }
func (m Machine) String() string { func (machine Machine) String() string {
return m.Name return machine.Name
} }
func (ms Machines) String() string { func (machines Machines) String() string {
temp := make([]string, len(ms)) temp := make([]string, len(machines))
for index, machine := range ms { for index, machine := range machines {
temp[index] = machine.Name temp[index] = machine.Name
} }
@ -379,24 +379,24 @@ func (ms Machines) String() string {
} }
// TODO(kradalby): Remove when we have generics... // TODO(kradalby): Remove when we have generics...
func (ms MachinesP) String() string { func (machines MachinesP) String() string {
temp := make([]string, len(ms)) temp := make([]string, len(machines))
for index, machine := range ms { for index, machine := range machines {
temp[index] = machine.Name temp[index] = machine.Name
} }
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (ms Machines) toNodes( func (machines Machines) toNodes(
baseDomain string, baseDomain string,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
includeRoutes bool, includeRoutes bool,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms)) nodes := make([]*tailcfg.Node, len(machines))
for index, machine := range ms { for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes) node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
@ -410,23 +410,24 @@ func (ms Machines) toNodes(
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS. // as per the expected behaviour in the official SaaS.
func (m Machine) toNode( func (machine Machine) toNode(
baseDomain string, baseDomain string,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
includeRoutes bool, includeRoutes bool,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey) nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mKey, err := wgkey.ParseHex(m.MachineKey)
machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var discoKey tailcfg.DiscoKey var discoKey tailcfg.DiscoKey
if m.DiscoKey != "" { if machine.DiscoKey != "" {
dKey, err := wgkey.ParseHex(m.DiscoKey) dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -436,12 +437,12 @@ func (m Machine) toNode(
} }
addrs := []netaddr.IPPrefix{} addrs := []netaddr.IPPrefix{}
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Str("ip", m.IPAddress). Str("ip", machine.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
return nil, err return nil, err
} }
@ -455,8 +456,8 @@ func (m Machine) toNode(
if includeRoutes { if includeRoutes {
routesStr := []string{} routesStr := []string{}
if len(m.EnabledRoutes) != 0 { if len(machine.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON() allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -476,8 +477,8 @@ func (m Machine) toNode(
} }
endpoints := []string{} endpoints := []string{}
if len(m.Endpoints) != 0 { if len(machine.Endpoints) != 0 {
be, err := m.Endpoints.MarshalJSON() be, err := machine.Endpoints.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -488,8 +489,8 @@ func (m Machine) toNode(
} }
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(machine.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON() hi, err := machine.HostInfo.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -507,29 +508,34 @@ func (m Machine) toNode(
} }
var keyExpiry time.Time var keyExpiry time.Time
if m.Expiry != nil { if machine.Expiry != nil {
keyExpiry = *m.Expiry keyExpiry = *machine.Expiry
} else { } else {
keyExpiry = time.Time{} keyExpiry = time.Time{}
} }
var hostname string var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain) hostname = fmt.Sprintf(
"%s.%s.%s",
machine.Name,
machine.Namespace.Name,
baseDomain,
)
} else { } else {
hostname = m.Name hostname = machine.Name
} }
n := tailcfg.Node{ n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID( StableID: tailcfg.StableNodeID(
strconv.FormatUint(m.ID, BASE_10), strconv.FormatUint(machine.ID, BASE_10),
), // in headscale, unlike tailcontrol server, IDs are permanent ), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname, Name: hostname,
User: tailcfg.UserID(m.NamespaceID), User: tailcfg.UserID(machine.NamespaceID),
Key: tailcfg.NodeKey(nKey), Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry, KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey), Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey, DiscoKey: discoKey,
Addresses: addrs, Addresses: addrs,
AllowedIPs: allowedIPs, AllowedIPs: allowedIPs,
@ -537,68 +543,73 @@ func (m Machine) toNode(
DERP: derp, DERP: derp,
Hostinfo: hostinfo, Hostinfo: hostinfo,
Created: m.CreatedAt, Created: machine.CreatedAt,
LastSeen: m.LastSeen, LastSeen: machine.LastSeen,
KeepAlive: true, KeepAlive: true,
MachineAuthorized: m.Registered, MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
return &n, nil return &n, nil
} }
func (m *Machine) toProto() *v1.Machine { func (machine *Machine) toProto() *v1.Machine {
machine := &v1.Machine{ machineProto := &v1.Machine{
Id: m.ID, Id: machine.ID,
MachineKey: m.MachineKey, MachineKey: machine.MachineKey,
NodeKey: m.NodeKey, NodeKey: machine.NodeKey,
DiscoKey: m.DiscoKey, DiscoKey: machine.DiscoKey,
IpAddress: m.IPAddress, IpAddress: machine.IPAddress,
Name: m.Name, Name: machine.Name,
Namespace: m.Namespace.toProto(), Namespace: machine.Namespace.toProto(),
Registered: m.Registered, Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter // TODO(kradalby): Implement register method enum converter
// RegisterMethod: , // RegisterMethod: ,
CreatedAt: timestamppb.New(m.CreatedAt), CreatedAt: timestamppb.New(machine.CreatedAt),
} }
if m.AuthKey != nil { if machine.AuthKey != nil {
machine.PreAuthKey = m.AuthKey.toProto() machineProto.PreAuthKey = machine.AuthKey.toProto()
} }
if m.LastSeen != nil { if machine.LastSeen != nil {
machine.LastSeen = timestamppb.New(*m.LastSeen) machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
} }
if m.LastSuccessfulUpdate != nil { if machine.LastSuccessfulUpdate != nil {
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) machineProto.LastSuccessfulUpdate = timestamppb.New(
*machine.LastSuccessfulUpdate,
)
} }
if m.Expiry != nil { if machine.Expiry != nil {
machine.Expiry = timestamppb.New(*m.Expiry) machineProto.Expiry = timestamppb.New(*machine.Expiry)
} }
return machine return machineProto
} }
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { func (h *Headscale) RegisterMachine(
ns, err := h.GetNamespace(namespace) key string,
namespaceName string,
) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mKey, err := wgkey.ParseHex(key) machineKey, err := wgkey.ParseHex(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := Machine{} machine := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is( if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
if m.isAlreadyRegistered() { if machine.isAlreadyRegistered() {
err := errors.New("Machine already registered") err := errors.New("Machine already registered")
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
return nil, err return nil, err
@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Could not find IP for the new machine") Msg("Could not find IP for the new machine")
return nil, err return nil, err
@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Found IP for host") Msg("Found IP for host")
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = ns.ID machine.NamespaceID = namespace.ID
m.Registered = true machine.Registered = true
m.RegisterMethod = "cli" machine.RegisterMethod = "cli"
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", m.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msg("Machine registered with the database") Msg("Machine registered with the database")
return &m, nil return &machine, nil
} }
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := m.GetHostInfo() hostInfo, err := machine.GetHostInfo()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
return hostInfo.RoutableIPs, nil return hostInfo.RoutableIPs, nil
} }
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := m.EnabledRoutes.MarshalJSON() data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil return routes, nil
} }
func (m *Machine) IsRoutesEnabled(routeStr string) bool { func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil { if err != nil {
return false return false
} }
enabledRoutes, err := m.GetEnabledRoutes() enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil { if err != nil {
return false return false
} }
@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the // EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes. // previous list of routes.
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs { for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route newRoutes[index] = route
} }
availableRoutes, err := m.GetAdvertisedRoutes() availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil { if err != nil {
return err return err
} }
@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
if !containsIpPrefix(availableRoutes, newRoute) { if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf( return fmt.Errorf(
"route (%s) is not available on node %s", "route (%s) is not available on node %s",
m.Name, machine.Name,
newRoute, newRoute,
) )
} }
@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err return err
} }
m.EnabledRoutes = datatypes.JSON(routes) machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m) h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID) err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil { if err != nil {
return err return err
} }
@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil return nil
} }
func (m *Machine) RoutesToProto() (*v1.Routes, error) { func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := m.GetAdvertisedRoutes() availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
enabledRoutes, err := m.GetEnabledRoutes() enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -32,12 +32,12 @@ type Namespace struct {
// CreateNamespace creates a new Namespace. Returns error if could not be created // CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists. // or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
n := Namespace{} namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errorNamespaceExists return nil, errorNamespaceExists
} }
n.Name = name namespace.Name = name
if err := h.db.Create(&n).Error; err != nil { if err := h.db.Create(&namespace).Error; err != nil {
log.Error(). log.Error().
Str("func", "CreateNamespace"). Str("func", "CreateNamespace").
Err(err). Err(err).
@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
return nil, err return nil, err
} }
return &n, nil return &namespace, nil
} }
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does // DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it. // not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error { func (h *Headscale) DestroyNamespace(name string) error {
n, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
return errorNamespaceNotFound return errorNamespaceNotFound
} }
m, err := h.ListMachinesInNamespace(name) machines, err := h.ListMachinesInNamespace(name)
if err != nil { if err != nil {
return err return err
} }
if len(m) > 0 { if len(machines) > 0 {
return errorNamespaceNotEmptyOfNodes return errorNamespaceNotEmptyOfNodes
} }
@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error {
if err != nil { if err != nil {
return err return err
} }
for _, p := range keys { for _, key := range keys {
err = h.DestroyPreAuthKey(&p) err = h.DestroyPreAuthKey(&key)
if err != nil { if err != nil {
return err return err
} }
} }
if result := h.db.Unscoped().Delete(&n); result.Error != nil { if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error return result.Error
} }
@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does // RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name. // not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error { func (h *Headscale) RenameNamespace(oldName, newName string) error {
n, err := h.GetNamespace(oldName) oldNamespace, err := h.GetNamespace(oldName)
if err != nil { if err != nil {
return err return err
} }
@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return err return err
} }
n.Name = newName oldNamespace.Name = newName
if result := h.db.Save(&n); result.Error != nil { if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error return result.Error
} }
err = h.RequestMapUpdates(n.ID) err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil { if err != nil {
return err return err
} }
@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
// GetNamespace fetches a namespace by name. // GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) { func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{} namespace := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is( if result := h.db.First(&namespace, "name = ?", name); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
return nil, errorNamespaceNotFound return nil, errorNamespaceNotFound
} }
return &n, nil return &namespace, nil
} }
// ListNamespaces gets all the existing namespaces. // ListNamespaces gets all the existing namespaces.
@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
// ListMachinesInNamespace gets all the nodes in a given namespace. // ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
n, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
machines := []Machine{} machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
} }
// SetMachineNamespace assigns a Machine to a namespace. // SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error { func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return err return err
} }
m.NamespaceID = n.ID machine.NamespaceID = namespace.ID
h.db.Save(&m) h.db.Save(&machine)
return nil return nil
} }
// TODO(kradalby): Remove the need for this.
// RequestMapUpdates signals the KV worker to update the maps for this namespace. // RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error { func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{} namespace := Namespace{}
@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return err return err
} }
v, err := h.getValue("namespaces_pending_updates") namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" { if err != nil || namespacesPendingUpdates == "" {
err = h.setValue( err = h.setValue(
"namespaces_pending_updates", "namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name), fmt.Sprintf(`["%s"]`, namespace.Name),
@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return nil return nil
} }
names := []string{} names := []string{}
err = json.Unmarshal([]byte(v), &names) err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil { if err != nil {
err = h.setValue( err = h.setValue(
"namespaces_pending_updates", "namespaces_pending_updates",
@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
} }
func (h *Headscale) checkForNamespacesPendingUpdates() { func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates") namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil { if err != nil {
return return
} }
if v == "" { if namespacesPendingUpdates == "" {
return return
} }
namespaces := []string{} namespaces := []string{}
err = json.Unmarshal([]byte(v), &namespaces) err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil { if err != nil {
return return
} }
@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace") Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace) h.setLastStateChangeToNow(namespace)
} }
newV, err := h.getValue("namespaces_pending_updates") newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil { if err != nil {
return return
} }
if v == newV { // only clear when no changes, so we notified everybody if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "") err = h.setValue("namespaces_pending_updates", "")
if err != nil { if err != nil {
log.Error(). log.Error().
@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
} }
func (n *Namespace) toUser() *tailcfg.User { func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{ user := tailcfg.User{
ID: tailcfg.UserID(n.ID), ID: tailcfg.UserID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User {
Created: time.Time{}, Created: time.Time{},
} }
return &u return &user
} }
func (n *Namespace) toLogin() *tailcfg.Login { func (n *Namespace) toLogin() *tailcfg.Login {
l := tailcfg.Login{ login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID), ID: tailcfg.LoginID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login {
Domain: "headscale.net", Domain: "headscale.net",
} }
return &l return &login
} }
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace) namespaceMap := make(map[string]Namespace)
namespaceMap[m.Namespace.Name] = m.Namespace namespaceMap[machine.Namespace.Name] = machine.Namespace
for _, p := range peers { for _, peer := range peers {
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
} }
profiles := []tailcfg.UserProfile{} profiles := []tailcfg.UserProfile{}

69
oidc.go
View file

@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(c *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := c.Param("mkey") mKeyStr := ctx.Param("mkey")
if mKeyStr == "" { if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
b := make([]byte, RANDOM_BYTE_SIZE) b := make([]byte, RANDOM_BYTE_SIZE)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return return
} }
@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
authUrl := h.oauth2Config.AuthCodeURL(stateStr) authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl) log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
c.Redirect(http.StatusFound, authUrl) ctx.Redirect(http.StatusFound, authUrl)
} }
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(c *gin.Context) { func (h *Headscale) OIDCCallback(ctx *gin.Context) {
code := c.Query("code") code := ctx.Query("code")
state := c.Query("state") state := ctx.Query("state")
if code == "" || state == "" { if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
} }
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token") ctx.String(http.StatusBadRequest, "Could not exchange code for token")
return return
} }
@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token") ctx.String(http.StatusBadRequest, "Could not extract ID Token")
return return
} }
@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return return
} }
@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims // Extract custom claims
var claims IDTokenClaims var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil { if err = idToken.Claims(&claims); err != nil {
c.String( ctx.String(
http.StatusBadRequest, http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err), fmt.Sprintf("Failed to decode id token claims: %s", err),
) )
@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyFound { if !mKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
return return
} }
@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyOK { if !mKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache") ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
)
return return
} }
// retrieve machine information // retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr) machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil { if err != nil {
log.Error().Msg("machine key not found in database") log.Error().Msg("machine key not found in database")
c.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine info from database", "could not get machine info from database",
) )
@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
now := time.Now().UTC() now := time.Now().UTC()
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new // register the machine if it's new
if !m.Registered { if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
ns, err = h.CreateNamespace(nsName) namespace, err = h.CreateNamespace(namespaceName)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("could not create new namespace '%s'", claims.Email) Msgf("could not create new namespace '%s'", claims.Email)
c.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not create new namespace", "could not create new namespace",
) )
@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
c.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get an IP from the pool", "could not get an IP from the pool",
) )
@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return return
} }
m.IPAddress = ip.String() machine.IPAddress = ip.String()
m.NamespaceID = ns.ID machine.NamespaceID = namespace.ID
m.Registered = true machine.Registered = true
m.RegisterMethod = "oidc" machine.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
} }
h.updateMachineExpiry(m) h.updateMachineExpiry(machine)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
log.Error(). log.Error().
Str("email", claims.Email). Str("email", claims.Email).
Str("username", claims.Username). Str("username", claims.Username).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace") Msg("Email could not be mapped to a namespace")
c.String( ctx.String(
http.StatusBadRequest, http.StatusBadRequest,
"email from claim could not be mapped to a namespace", "email from claim could not be mapped to a namespace",
) )

View file

@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream(
) { ) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
c.Stream(func(w io.Writer) bool { c.Stream(func(writer io.Writer) bool {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
_, err := w.Write(data) _, err := writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending keep alive message") Msg("Sending keep alive message")
_, err := w.Write(data) _, err := writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
} }
_, err = w.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").

View file

@ -18,13 +18,16 @@ type SharedMachine struct {
} }
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace. // AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) AddSharedMachineToNamespace(
if m.NamespaceID == ns.ID { machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
return errorSameNamespace return errorSameNamespace
} }
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil { if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err return err
} }
if len(sharedMachines) > 0 { if len(sharedMachines) > 0 {
@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
} }
sharedMachine := SharedMachine{ sharedMachine := SharedMachine{
MachineID: m.ID, MachineID: machine.ID,
Machine: *m, Machine: *machine,
NamespaceID: ns.ID, NamespaceID: namespace.ID,
Namespace: *ns, Namespace: *namespace,
} }
h.db.Save(&sharedMachine) h.db.Save(&sharedMachine)
@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
} }
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace. // RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) RemoveSharedMachineFromNamespace(
if m.NamespaceID == ns.ID { machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace // Can't unshare from primary namespace
return errorMachineNotShared return errorMachineNotShared
} }
sharedMachine := SharedMachine{} sharedMachine := SharedMachine{}
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID). result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
Unscoped(). Unscoped().
Delete(&sharedMachine) Delete(&sharedMachine)
if result.Error != nil { if result.Error != nil {
@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return errorMachineNotShared return errorMachineNotShared
} }
err := h.RequestMapUpdates(ns.ID) err := h.RequestMapUpdates(namespace.ID)
if err != nil { if err != nil {
return err return err
} }
@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
} }
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces. // RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{} sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error return result.Error
} }

View file

@ -13,8 +13,8 @@ import (
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte var apiV1JSON []byte
func SwaggerUI(c *gin.Context) { func SwaggerUI(ctx *gin.Context) {
t := template.Must(template.New("swagger").Parse(` swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html> <html>
<head> <head>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"> <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) {
</html>`)) </html>`))
var payload bytes.Buffer var payload bytes.Buffer
if err := t.Execute(&payload, struct{}{}); err != nil { if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Could not render Swagger") Msg("Could not render Swagger")
c.Data( ctx.Data(
http.StatusInternalServerError, http.StatusInternalServerError,
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Swagger"), []byte("Could not render Swagger"),
@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) {
return return
} }
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
} }
func SwaggerAPIv1(c *gin.Context) { func SwaggerAPIv1(ctx *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
} }

View file

@ -36,7 +36,7 @@ func decode(
func decodeMsg( func decodeMsg(
msg []byte, msg []byte,
v interface{}, output interface{},
pubKey *wgkey.Key, pubKey *wgkey.Key,
privKey *wgkey.Private, privKey *wgkey.Private,
) error { ) error {
@ -45,7 +45,7 @@ func decodeMsg(
return err return err
} }
// fmt.Println(string(decrypted)) // fmt.Println(string(decrypted))
if err := json.Unmarshal(decrypted, v); err != nil { if err := json.Unmarshal(decrypted, output); err != nil {
return fmt.Errorf("response: %v", err) return fmt.Errorf("response: %v", err)
} }
@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey) return encodeMsg(b, pubKey, privKey)
} }
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { func encodeMsg(
payload []byte,
pubKey *wgkey.Key,
privKey *wgkey.Private,
) ([]byte, error) {
var nonce [24]byte var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err) panic(err)
} }
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri) msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
return msg, nil return msg, nil
} }