From 30a2ccd9758c96efa42e2f9691f151be480ec5a6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 12 Feb 2022 17:05:30 +0000 Subject: [PATCH] Add tls certs as creds for grpc --- app.go | 116 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/app.go b/app.go index 987e64e4..8d228b48 100644 --- a/app.go +++ b/app.go @@ -34,6 +34,8 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/credentials" // "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -474,6 +476,13 @@ func (h *Headscale) Serve() error { zerolog.RespLog = false } + // Prepare group for running listeners + errorGroup := new(errgroup.Group) + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // // // Set up LOCAL listeners @@ -507,39 +516,6 @@ func (h *Headscale) Serve() error { os.Exit(0) }(sigc) - // - // - // Set up REMOTE listeners - // - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - - defer cancel() - - tlsConfig, err := h.getTLSSettings() - if err != nil { - log.Error().Err(err).Msg("Failed to set up TLS configuration") - - return err - } - - // var httpListener net.Listener - // - // if tlsConfig != nil { - // httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig) - // } else { - // httpListener, err = net.Listen("tcp", h.cfg.Addr) - // } - // if err != nil { - // return fmt.Errorf("failed to bind to TCP address: %w", err) - // } - // - - // - // - // gRPC setup - // - grpcGatewayMux := runtime.NewServeMux() // Make the grpc-gateway connect to grpc over socket @@ -561,33 +537,63 @@ func (h *Headscale) Serve() error { return err } - grpcOptions := []grpc.ServerOption{ - grpc.UnaryInterceptor( - grpc_middleware.ChainUnaryServer( - h.grpcAuthenticationInterceptor, - zerolog.NewUnaryServerInterceptor(), - ), - ), - } - - grpcServer := grpc.NewServer(grpcOptions...) - // Start the local gRPC server without TLS and without authentication grpcSocket := grpc.NewServer(zerolog.UnaryInterceptor()) - v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) v1.RegisterHeadscaleServiceServer(grpcSocket, newHeadscaleV1APIServer(h)) - reflection.Register(grpcServer) reflection.Register(grpcSocket) - var grpcListener net.Listener - if tlsConfig != nil { - grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig) - } else { - grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) - } + errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) }) + + // + // + // Set up REMOTE listeners + // + + tlsConfig, err := h.getTLSSettings() if err != nil { - return fmt.Errorf("failed to bind to TCP address: %w", err) + log.Error().Err(err).Msg("Failed to set up TLS configuration") + + return err + } + + // + // + // gRPC setup + // + + // If TLS has been enabled, set up the remote gRPC server + if tlsConfig != nil { + log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) + + grpcOptions := []grpc.ServerOption{ + grpc.UnaryInterceptor( + grpc_middleware.ChainUnaryServer( + h.grpcAuthenticationInterceptor, + zerolog.NewUnaryServerInterceptor(), + ), + ), + grpc.Creds(credentials.NewTLS(tlsConfig)), + } + + grpcServer := grpc.NewServer(grpcOptions...) + + v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) + reflection.Register(grpcServer) + + var grpcListener net.Listener + // if tlsConfig != nil { + // grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig) + // } else { + grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) + // } + if err != nil { + return fmt.Errorf("failed to bind to TCP address: %w", err) + } + + errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) }) + } else { + log.Info().Msg("TLS is not configured, not enabling remote gRPC") } // @@ -619,10 +625,6 @@ func (h *Headscale) Serve() error { return fmt.Errorf("failed to bind to TCP address: %w", err) } - errorGroup := new(errgroup.Group) - - errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) }) - errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) }) errorGroup.Go(func() error { return httpServer.Serve(httpListener) }) log.Info().