From 80772033ee98afa401b2d92ec1b67480ff33f9e4 Mon Sep 17 00:00:00 2001 From: Juan Font Date: Tue, 2 May 2023 08:15:33 +0200 Subject: [PATCH] Improvements on Noise implementation (#1379) --- CHANGELOG.md | 3 +- noise.go | 114 ++++++++++++++++++++++++++++++++++++----- protocol_noise.go | 11 +++- protocol_noise_poll.go | 13 +++-- 4 files changed, 121 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72f2dee1..f21e931b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,13 @@ - Add environment flags to enable pprof (profiling) [#1382](https://github.com/juanfont/headscale/pull/1382) - Profiles are continously generated in our integration tests. - Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391) +- Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379) ## 0.22.1 (2023-04-20) ### Changes -- Fix issue where SystemD could not bind to port 80 [#1365](https://github.com/juanfont/headscale/pull/1365) +- Fix issue where systemd could not bind to port 80 [#1365](https://github.com/juanfont/headscale/pull/1365) ## 0.22.0 (2023-04-20) diff --git a/noise.go b/noise.go index 56967148..9cc489ae 100644 --- a/noise.go +++ b/noise.go @@ -1,6 +1,9 @@ package headscale import ( + "encoding/binary" + "encoding/json" + "io" "net/http" "github.com/gorilla/mux" @@ -9,18 +12,37 @@ import ( "golang.org/x/net/http2/h2c" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" - "tailscale.com/net/netutil" + "tailscale.com/tailcfg" + "tailscale.com/types/key" ) const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" + + // The first 9 bytes from the server to client over Noise are either an HTTP/2 + // settings frame (a normal HTTP/2 setup) or, as Tailscale added later, an "early payload" + // header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes + // of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise. + // The early payload is optional. Some servers may not send it... But we do! + earlyPayloadMagic = "\xff\xff\xffTS" + + // EarlyNoise was added in protocol version 49. + earlyNoiseCapabilityVersion = 49 ) -type ts2021App struct { +type noiseServer struct { headscale *Headscale - conn *controlbase.Conn + httpBaseConfig *http.Server + http2Server *http2.Server + conn *controlbase.Conn + machineKey key.MachinePublic + nodeKey key.NodePublic + + // EarlyNoise-related stuff + challenge key.ChallengePrivate + protocolVersion int } // NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn @@ -44,7 +66,18 @@ func (h *Headscale) NoiseUpgradeHandler( return } - noiseConn, err := controlhttp.AcceptHTTP(req.Context(), writer, req, *h.noisePrivateKey, nil) + noiseServer := noiseServer{ + headscale: h, + challenge: key.NewChallenge(), + } + + noiseConn, err := controlhttp.AcceptHTTP( + req.Context(), + writer, + req, + *h.noisePrivateKey, + noiseServer.earlyNoise, + ) if err != nil { log.Error().Err(err).Msg("noise upgrade failed") http.Error(writer, err.Error(), http.StatusInternalServerError) @@ -52,10 +85,9 @@ func (h *Headscale) NoiseUpgradeHandler( return } - ts2021App := ts2021App{ - headscale: h, - conn: noiseConn, - } + noiseServer.conn = noiseConn + noiseServer.machineKey = noiseServer.conn.Peer() + noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion() // This router is served only over the Noise connection, and exposes only the new API. // @@ -63,16 +95,70 @@ func (h *Headscale) NoiseUpgradeHandler( // a single hijacked connection from /ts2021, using netutil.NewOneConnListener router := mux.NewRouter() - router.HandleFunc("/machine/register", ts2021App.NoiseRegistrationHandler). + router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). Methods(http.MethodPost) - router.HandleFunc("/machine/map", ts2021App.NoisePollNetMapHandler) + router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) server := http.Server{ ReadTimeout: HTTPReadTimeout, } - server.Handler = h2c.NewHandler(router, &http2.Server{}) - err = server.Serve(netutil.NewOneConnListener(noiseConn, nil)) - if err != nil { - log.Info().Err(err).Msg("The HTTP2 server was closed") + + noiseServer.httpBaseConfig = &http.Server{ + Handler: router, + ReadHeaderTimeout: HTTPReadTimeout, } + noiseServer.http2Server = &http2.Server{} + + server.Handler = h2c.NewHandler(router, noiseServer.http2Server) + + noiseServer.http2Server.ServeConn( + noiseConn, + &http2.ServeConnOpts{ + BaseConfig: noiseServer.httpBaseConfig, + }, + ) +} + +func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { + log.Trace(). + Caller(). + Int("protocol_version", protocolVersion). + Str("challenge", ns.challenge.Public().String()). + Msg("earlyNoise called") + + if protocolVersion < earlyNoiseCapabilityVersion { + log.Trace(). + Caller(). + Msgf("protocol version %d does not support early noise", protocolVersion) + + return nil + } + + earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{ + NodeKeyChallenge: ns.challenge.Public(), + }) + if err != nil { + return err + } + + // 5 bytes that won't be mistaken for an HTTP/2 frame: + // https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not + // an HTTP/2 settings frame, which isn't of type 'T') + var notH2Frame [5]byte + copy(notH2Frame[:], earlyPayloadMagic) + var lenBuf [4]byte + binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) + // These writes are all buffered by caller, so fine to do them + // separately: + if _, err := writer.Write(notH2Frame[:]); err != nil { + return err + } + if _, err := writer.Write(lenBuf[:]); err != nil { + return err + } + if _, err := writer.Write(earlyJSON); err != nil { + return err + } + + return nil } diff --git a/protocol_noise.go b/protocol_noise.go index 1d1b9c8a..eb18a474 100644 --- a/protocol_noise.go +++ b/protocol_noise.go @@ -10,7 +10,7 @@ import ( ) // // NoiseRegistrationHandler handles the actual registration process of a machine. -func (t *ts2021App) NoiseRegistrationHandler( +func (ns *noiseServer) NoiseRegistrationHandler( writer http.ResponseWriter, req *http.Request, ) { @@ -20,6 +20,11 @@ func (t *ts2021App) NoiseRegistrationHandler( return } + + log.Trace(). + Any("headers", req.Header). + Msg("Headers") + body, _ := io.ReadAll(req.Body) registerRequest := tailcfg.RegisterRequest{} if err := json.Unmarshal(body, ®isterRequest); err != nil { @@ -33,5 +38,7 @@ func (t *ts2021App) NoiseRegistrationHandler( return } - t.headscale.handleRegisterCommon(writer, req, registerRequest, t.conn.Peer(), true) + ns.nodeKey = registerRequest.NodeKey + + ns.headscale.handleRegisterCommon(writer, req, registerRequest, ns.conn.Peer(), true) } diff --git a/protocol_noise_poll.go b/protocol_noise_poll.go index 2c298693..d5e9a8a9 100644 --- a/protocol_noise_poll.go +++ b/protocol_noise_poll.go @@ -21,13 +21,18 @@ import ( // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (t *ts2021App) NoisePollNetMapHandler( +func (ns *noiseServer) NoisePollNetMapHandler( writer http.ResponseWriter, req *http.Request, ) { log.Trace(). Str("handler", "NoisePollNetMap"). Msg("PollNetMapHandler called") + + log.Trace(). + Any("headers", req.Header). + Msg("Headers") + body, _ := io.ReadAll(req.Body) mapRequest := tailcfg.MapRequest{} @@ -41,7 +46,9 @@ func (t *ts2021App) NoisePollNetMapHandler( return } - machine, err := t.headscale.GetMachineByAnyKey(t.conn.Peer(), mapRequest.NodeKey, key.NodePublic{}) + ns.nodeKey = mapRequest.NodeKey + + machine, err := ns.headscale.GetMachineByAnyKey(ns.conn.Peer(), mapRequest.NodeKey, key.NodePublic{}) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). @@ -63,5 +70,5 @@ func (t *ts2021App) NoisePollNetMapHandler( Str("machine", machine.Hostname). Msg("A machine is entering polling via the Noise protocol") - t.headscale.handlePollCommon(writer, req.Context(), machine, mapRequest, true) + ns.headscale.handlePollCommon(writer, req.Context(), machine, mapRequest, true) }