mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-03 04:03:05 +00:00
WIP: Client updates. Long polling rewritten
This commit is contained in:
parent
ca6904fc95
commit
06fb7d4587
2 changed files with 105 additions and 40 deletions
8
app.go
8
app.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -30,6 +31,9 @@ type Headscale struct {
|
||||||
dbString string
|
dbString string
|
||||||
publicKey *wgcfg.Key
|
publicKey *wgcfg.Key
|
||||||
privateKey *wgcfg.PrivateKey
|
privateKey *wgcfg.PrivateKey
|
||||||
|
|
||||||
|
pollMu sync.Mutex
|
||||||
|
clientsPolling map[uint64]chan []byte // this is by all means a hackity hack
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeadscale returns the Headscale app
|
// NewHeadscale returns the Headscale app
|
||||||
|
@ -54,6 +58,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
h.clientsPolling = make(map[uint64]chan []byte)
|
||||||
return &h, nil
|
return &h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,9 +69,6 @@ func (h *Headscale) Serve() error {
|
||||||
r.GET("/register", h.RegisterWebAPI)
|
r.GET("/register", h.RegisterWebAPI)
|
||||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||||
r.POST("/machine/:id", h.RegistrationHandler)
|
r.POST("/machine/:id", h.RegistrationHandler)
|
||||||
|
|
||||||
// r.LoadHTMLFiles("./frontend/build/index.html")
|
|
||||||
// r.Use(static.Serve("/", static.LocalFile("./frontend/build", true)))
|
|
||||||
err := r.Run(h.cfg.Addr)
|
err := r.Run(h.cfg.Addr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
133
handlers.go
133
handlers.go
|
@ -57,7 +57,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// We do have the updated key!
|
// We do have the updated key!
|
||||||
if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
|
if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
|
||||||
if m.Registered {
|
if m.Registered {
|
||||||
log.Println("Registered and we have the updated key! Lets move to map")
|
log.Println("Client is registered and we have the current key. All clear to /map")
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -102,50 +102,73 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
log.Println("We dont know anything about the new key. WTF")
|
log.Println("We dont know anything about the new key. WTF")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PollNetMapHandler takes care of /machine/:id/map
|
||||||
|
//
|
||||||
|
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
||||||
|
// the clients when something in the network changes.
|
||||||
|
//
|
||||||
|
// The clients POST stuff like HostInfo and their Endpoints here, but
|
||||||
|
// only after their first request (marked with the ReadOnly field).
|
||||||
|
//
|
||||||
|
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
body, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(c.Request.Body)
|
||||||
mKeyStr := c.Param("id")
|
mKeyStr := c.Param("id")
|
||||||
mKey, err := wgcfg.ParseHexKey(mKeyStr)
|
mKey, err := wgcfg.ParseHexKey(mKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot parse client key: %s", err)
|
log.Printf("Cannot parse client key: %s", err)
|
||||||
c.String(http.StatusOK, "Sad!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req := tailcfg.MapRequest{}
|
req := tailcfg.MapRequest{}
|
||||||
err = decode(body, &req, &mKey, h.privateKey)
|
err = decode(body, &req, &mKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot decode message: %s", err)
|
log.Printf("Cannot decode message: %s", err)
|
||||||
c.String(http.StatusOK, "Very sad!")
|
return
|
||||||
// return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot open DB: %s", err)
|
log.Printf("Cannot open DB: %s", err)
|
||||||
c.String(http.StatusInternalServerError, ":(")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
var m Machine
|
var m Machine
|
||||||
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
|
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
|
||||||
log.Printf("Cannot encode message: %s", err)
|
log.Printf("Cannot find machine: %s", err)
|
||||||
c.String(http.StatusOK, "Extremely sad!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoints, _ := json.Marshal(req.Endpoints)
|
|
||||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||||
m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)}
|
m.Name = req.Hostinfo.Hostname
|
||||||
m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)}
|
m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)}
|
||||||
m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString()
|
m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString()
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// From Tailscale client:
|
||||||
|
//
|
||||||
|
// ReadOnly is whether the client just wants to fetch the MapResponse,
|
||||||
|
// without updating their Endpoints. The Endpoints field will be ignored and
|
||||||
|
// LastSeen will not be updated and peers will not be notified of changes.
|
||||||
|
//
|
||||||
|
// The intended use is for clients to discover the DERP map at start-up
|
||||||
|
// before their first real endpoint update.
|
||||||
|
if !req.ReadOnly {
|
||||||
|
endpoints, _ := json.Marshal(req.Endpoints)
|
||||||
|
m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)}
|
||||||
m.LastSeen = &now
|
m.LastSeen = &now
|
||||||
|
}
|
||||||
db.Save(&m)
|
db.Save(&m)
|
||||||
db.Close()
|
db.Close()
|
||||||
|
|
||||||
chanStream := make(chan []byte, 1)
|
pollData := make(chan []byte, 1)
|
||||||
go func() {
|
update := make(chan []byte, 1)
|
||||||
defer close(chanStream)
|
cancelKeepAlive := make(chan []byte, 1)
|
||||||
|
defer close(pollData)
|
||||||
|
defer close(update)
|
||||||
|
defer close(cancelKeepAlive)
|
||||||
|
h.pollMu.Lock()
|
||||||
|
h.clientsPolling[m.ID] = update
|
||||||
|
h.pollMu.Unlock()
|
||||||
|
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
data, err := h.getMapResponse(mKey, req, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -153,34 +176,73 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//send initial dump
|
log.Printf("[%s] sending initial map", m.Name)
|
||||||
chanStream <- *data
|
pollData <- *data
|
||||||
for {
|
|
||||||
|
|
||||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
// We update our peers if the client is not sending ReadOnly in the MapRequest
|
||||||
if err != nil {
|
// so we don't distribute its initial request (it comes with
|
||||||
c.String(http.StatusInternalServerError, ":(")
|
// empty endpoints to peers)
|
||||||
return
|
if !req.ReadOnly {
|
||||||
}
|
peers, _ := h.getPeers(m)
|
||||||
chanStream <- *data
|
h.pollMu.Lock()
|
||||||
// keep the node entertained
|
for _, p := range *peers {
|
||||||
time.Sleep(time.Second * 180)
|
log.Printf("[%s] notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0])
|
||||||
break
|
if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok {
|
||||||
}
|
pUp <- []byte{}
|
||||||
|
|
||||||
}()
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
if msg, ok := <-chanStream; ok {
|
|
||||||
log.Printf("🦀 Sending data to %s: %d bytes", c.Request.RemoteAddr, len(msg))
|
|
||||||
w.Write(msg)
|
|
||||||
return true
|
|
||||||
} else {
|
} else {
|
||||||
log.Printf("🦄 Closing connection to %s", c.Request.RemoteAddr)
|
log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name)
|
||||||
c.AbortWithStatus(200)
|
}
|
||||||
|
}
|
||||||
|
h.pollMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
|
||||||
|
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-pollData:
|
||||||
|
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data))
|
||||||
|
w.Write(data)
|
||||||
|
return true
|
||||||
|
|
||||||
|
case <-update:
|
||||||
|
log.Printf("[%s] Received a request for update", m.Name)
|
||||||
|
data, err := h.getMapResponse(mKey, req, m)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[%s] 🤮 Cannot get the poll response: %s", m.Name, err)
|
||||||
|
}
|
||||||
|
w.Write(*data)
|
||||||
|
return true
|
||||||
|
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
log.Printf("[%s] 😥 The client has closed the connection", m.Name)
|
||||||
|
h.pollMu.Lock()
|
||||||
|
cancelKeepAlive <- []byte{}
|
||||||
|
delete(h.clientsPolling, m.ID)
|
||||||
|
h.pollMu.Unlock()
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cancel:
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error generating the keep alive msg: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pollData <- *data
|
||||||
|
time.Sleep(60 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
|
func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
|
||||||
|
@ -221,7 +283,7 @@ func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Mac
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// spew.Dump(resp)
|
||||||
// declare the incoming size on the first 4 bytes
|
// declare the incoming size on the first 4 bytes
|
||||||
data := make([]byte, 4)
|
data := make([]byte, 4)
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||||
|
@ -289,6 +351,7 @@ func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key
|
||||||
MachineKey: idKey.HexString(),
|
MachineKey: idKey.HexString(),
|
||||||
NodeKey: wgcfg.Key(req.NodeKey).HexString(),
|
NodeKey: wgcfg.Key(req.NodeKey).HexString(),
|
||||||
Expiry: &req.Expiry,
|
Expiry: &req.Expiry,
|
||||||
|
Name: req.Hostinfo.Hostname,
|
||||||
}
|
}
|
||||||
if err := db.Create(&mNew).Error; err != nil {
|
if err := db.Create(&mNew).Error; err != nil {
|
||||||
log.Printf("Could not create row: %s", err)
|
log.Printf("Could not create row: %s", err)
|
||||||
|
|
Loading…
Reference in a new issue