Add tls certs as creds for grpc

This commit is contained in:
Kristoffer Dalby 2022-02-12 17:05:30 +00:00
parent 59e48993f2
commit 30a2ccd975

94
app.go
View file

@ -34,6 +34,8 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/credentials"
// "google.golang.org/grpc/credentials" // "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -474,6 +476,13 @@ func (h *Headscale) Serve() error {
zerolog.RespLog = false zerolog.RespLog = false
} }
// Prepare group for running listeners
errorGroup := new(errgroup.Group)
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// //
// //
// Set up LOCAL listeners // Set up LOCAL listeners
@ -507,39 +516,6 @@ func (h *Headscale) Serve() error {
os.Exit(0) os.Exit(0)
}(sigc) }(sigc)
//
//
// Set up REMOTE listeners
//
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")
return err
}
// var httpListener net.Listener
//
// if tlsConfig != nil {
// httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
// } else {
// httpListener, err = net.Listen("tcp", h.cfg.Addr)
// }
// if err != nil {
// return fmt.Errorf("failed to bind to TCP address: %w", err)
// }
//
//
//
// gRPC setup
//
grpcGatewayMux := runtime.NewServeMux() grpcGatewayMux := runtime.NewServeMux()
// Make the grpc-gateway connect to grpc over socket // Make the grpc-gateway connect to grpc over socket
@ -561,6 +537,35 @@ func (h *Headscale) Serve() error {
return err return err
} }
// Start the local gRPC server without TLS and without authentication
grpcSocket := grpc.NewServer(zerolog.UnaryInterceptor())
v1.RegisterHeadscaleServiceServer(grpcSocket, newHeadscaleV1APIServer(h))
reflection.Register(grpcSocket)
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
//
//
// Set up REMOTE listeners
//
tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")
return err
}
//
//
// gRPC setup
//
// If TLS has been enabled, set up the remote gRPC server
if tlsConfig != nil {
log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
grpcOptions := []grpc.ServerOption{ grpcOptions := []grpc.ServerOption{
grpc.UnaryInterceptor( grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer( grpc_middleware.ChainUnaryServer(
@ -568,28 +573,29 @@ func (h *Headscale) Serve() error {
zerolog.NewUnaryServerInterceptor(), zerolog.NewUnaryServerInterceptor(),
), ),
), ),
grpc.Creds(credentials.NewTLS(tlsConfig)),
} }
grpcServer := grpc.NewServer(grpcOptions...) grpcServer := grpc.NewServer(grpcOptions...)
// Start the local gRPC server without TLS and without authentication
grpcSocket := grpc.NewServer(zerolog.UnaryInterceptor())
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
v1.RegisterHeadscaleServiceServer(grpcSocket, newHeadscaleV1APIServer(h))
reflection.Register(grpcServer) reflection.Register(grpcServer)
reflection.Register(grpcSocket)
var grpcListener net.Listener var grpcListener net.Listener
if tlsConfig != nil { // if tlsConfig != nil {
grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig) // grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig)
} else { // } else {
grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
} // }
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err) return fmt.Errorf("failed to bind to TCP address: %w", err)
} }
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
} else {
log.Info().Msg("TLS is not configured, not enabling remote gRPC")
}
// //
// //
// HTTP setup // HTTP setup
@ -619,10 +625,6 @@ func (h *Headscale) Serve() error {
return fmt.Errorf("failed to bind to TCP address: %w", err) return fmt.Errorf("failed to bind to TCP address: %w", err)
} }
errorGroup := new(errgroup.Group)
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
errorGroup.Go(func() error { return httpServer.Serve(httpListener) }) errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
log.Info(). log.Info().