diff --git a/.dockerignore b/.dockerignore index 33f9aea2..057a20e7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -14,3 +14,5 @@ docker-compose* README.md LICENSE .vscode + +*.sock diff --git a/app.go b/app.go index c77a85a5..94517d54 100644 --- a/app.go +++ b/app.go @@ -39,9 +39,7 @@ import ( ) const ( - LOCALHOST_V4 = "127.0.0.1" - LOCALHOST_V6 = "[::1]" - AUTH_PREFIX = "Bearer " + AUTH_PREFIX = "Bearer " ) // Config contains the initial Headscale configuration. @@ -75,6 +73,8 @@ type Config struct { ACMEEmail string DNSConfig *tailcfg.DNSConfig + + UnixSocket string } type DERPConfig struct { @@ -233,8 +233,9 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, // the server p, _ := peer.FromContext(ctx) - if IsLocalhost(p.Addr.String()) { - log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connected from localhost") + // TODO(kradalby): Figure out what @ means (socket wise) and if it can be exploited + if p.Addr.String() == "@" { + log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connecting over socket") return handler(ctx, req) } @@ -326,14 +327,19 @@ func (h *Headscale) Serve() error { defer cancel() - l, err := net.Listen("tcp", h.cfg.Addr) + socketListener, err := net.Listen("unix", h.cfg.UnixSocket) + if err != nil { + panic(err) + } + + networkListener, err := net.Listen("tcp", h.cfg.Addr) if err != nil { panic(err) } // Create the cmux object that will multiplex 2 protocols on the same port. // The two following listeners will be served on the same port below gracefully. - m := cmux.New(l) + m := cmux.New(networkListener) // Match gRPC requests here grpcListener := m.MatchWithWriters( cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"), @@ -344,16 +350,23 @@ func (h *Headscale) Serve() error { grpcGatewayMux := runtime.NewServeMux() - grpcDialOptions := []grpc.DialOption{grpc.WithInsecure()} - - _, port, err := net.SplitHostPort(h.cfg.Addr) + // Make the grpc-gateway connect to grpc over socket + grpcGatewayConn, err := grpc.Dial( + h.cfg.UnixSocket, + []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + }), + }..., + ) if err != nil { return err } // Connect to the gRPC server over localhost to skip // the authentication. - err = apiV1.RegisterHeadscaleServiceHandlerFromEndpoint(ctx, grpcGatewayMux, LOCALHOST_V4+":"+port, grpcDialOptions) + err = apiV1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn) if err != nil { return err } @@ -432,6 +445,7 @@ func (h *Headscale) Serve() error { g := new(errgroup.Group) + g.Go(func() error { return grpcServer.Serve(socketListener) }) g.Go(func() error { return grpcServer.Serve(grpcListener) }) g.Go(func() error { return httpServer.Serve(httpListener) }) g.Go(func() error { return m.Serve() }) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 14ff2bf2..1d6fdd61 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -48,6 +48,8 @@ func LoadConfig(path string) error { viper.SetDefault("dns_config", nil) + viper.SetDefault("unix_socket", "/var/run/headscale.sock") + err := viper.ReadInConfig() if err != nil { return fmt.Errorf("Fatal error reading config file: %s \n", err) @@ -242,6 +244,8 @@ func getHeadscaleConfig() headscale.Config { ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), + + UnixSocket: viper.GetString("unix_socket"), } } @@ -282,11 +286,11 @@ func getHeadscaleApp() (*headscale.Headscale, error) { } func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) { + // TODO(kradalby): Make configurable ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() grpcOptions := []grpc.DialOption{ - // TODO(kradalby): Make configurable grpc.WithBlock(), } @@ -294,19 +298,24 @@ func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) { // If the address is not set, we assume that we are on the server hosting headscale. if address == "" { - log.Debug().Msgf("HEADSCALE_ADDRESS environment is not set, connecting to localhost.") cfg := getHeadscaleConfig() - _, port, _ := net.SplitHostPort(cfg.Addr) + log.Debug(). + Str("socket", cfg.UnixSocket). + Msgf("HEADSCALE_ADDRESS environment is not set, connecting to unix socket.") - address = "127.0.0.1" + ":" + port + address = cfg.UnixSocket - grpcOptions = append(grpcOptions, grpc.WithInsecure()) - } - - // If we are not connecting to a local server, require an API key for authentication - if !headscale.IsLocalhost(address) { + grpcOptions = append( + grpcOptions, + grpc.WithInsecure(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + }), + ) + } else { + // If we are not connecting to a local server, require an API key for authentication apiKey := os.Getenv("HEADSCALE_API_KEY") if apiKey == "" { log.Fatal().Msgf("HEADSCALE_API_KEY environment variable needs to be set.") diff --git a/config-example.yaml b/config-example.yaml index 59370eb5..0eaf4c2d 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -64,3 +64,8 @@ dns_config: magic_dns: true base_domain: example.com + +# Unix socket used for the CLI to connect without authentication +# Note: for local development, you probably want to change this to: +# unix_socket: ./headscale.sock +unix_socket: /var/run/headscale.sock diff --git a/proto/headscale/v1/rpc.proto b/proto/headscale/v1/rpc.proto deleted file mode 100644 index 2e6b0a67..00000000 --- a/proto/headscale/v1/rpc.proto +++ /dev/null @@ -1,106 +0,0 @@ -syntax = "proto3"; -package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; - -import "google/protobuf/timestamp.proto"; -import "google/api/annotations.proto"; - -enum RegisterMethod { - REGISTER_METHOD_UNSPECIFIED = 0; - REGISTER_METHOD_AUTH_KEY = 1; - REGISTER_METHOD_CLI = 2; - REGISTER_METHOD_OIDC = 3; -} - -// message PreAuthKey { -// uint64 id = 1; -// string key = 2; -// uint32 namespace_id = 3; -// Namespace namespace = 4; -// bool reusable = 5; -// bool ephemeral = 6; -// bool used = 7; -// -// google.protobuf.Timestamp created_at = 8; -// google.protobuf.Timestamp expiration = 9; -// } - -message GetMachineRequest { - uint64 machine_id = 1; -} - -message GetMachineResponse { - uint64 id = 1; - string machine_key = 2; - string node_key = 3; - string disco_key = 4; - string ip_address = 5; - string name = 6; - uint32 namespace_id = 7; - - bool registered = 8; - RegisterMethod register_method = 9; - uint32 auth_key_id = 10; - // PreAuthKey auth_key = 11; - - google.protobuf.Timestamp last_seen = 12; - google.protobuf.Timestamp last_successful_update = 13; - google.protobuf.Timestamp expiry = 14; - - // bytes host_info = 15; - // bytes endpoints = 16; - // bytes enabled_routes = 17; - - // google.protobuf.Timestamp created_at = 18; - // google.protobuf.Timestamp updated_at = 19; - // google.protobuf.Timestamp deleted_at = 20; -} - -message CreateNamespaceRequest { - string name = 1; -} - -message CreateNamespaceResponse { - string name = 1; -} - -message DeleteNamespaceRequest { - string name = 1; -} - -message DeleteNamespaceResponse { -} - -message ListNamespacesRequest { -} - -message ListNamespacesResponse { - repeated string namespaces = 1; -} - -service HeadscaleService { - rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) { - option(google.api.http) = { - get : "/api/v1/machine/{machine_id}" - }; - } - - rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) { - option(google.api.http) = { - post : "/api/v1/namespace" - body : "*" - }; - } - - rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) { - option(google.api.http) = { - delete : "/api/v1/namespace" - }; - } - - rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) { - option(google.api.http) = { - get : "/api/v1/namespace" - }; - } -} diff --git a/utils.go b/utils.go index 1aa31712..cbe1d870 100644 --- a/utils.go +++ b/utils.go @@ -156,11 +156,3 @@ func tailNodesToString(nodes []*tailcfg.Node) string { func tailMapResponseToString(resp tailcfg.MapResponse) string { return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers)) } - -func IsLocalhost(host string) bool { - if strings.Contains(host, LOCALHOST_V4) || strings.Contains(host, LOCALHOST_V6) { - return true - } - - return false -}