Switch from gRPC localhost to socket

This commit changes the way CLI and grpc-gateway communicates with the
gRPC backend to socket, instead of localhost. Unauthenticated access now
goes on the socket, while the network interface will require API key (in
the future).
This commit is contained in:
Kristoffer Dalby 2021-10-30 14:08:16 +00:00
parent 72fd2a2780
commit 6aacada852
6 changed files with 50 additions and 134 deletions

View file

@ -14,3 +14,5 @@ docker-compose*
README.md
LICENSE
.vscode
*.sock

36
app.go
View file

@ -39,9 +39,7 @@ import (
)
const (
LOCALHOST_V4 = "127.0.0.1"
LOCALHOST_V6 = "[::1]"
AUTH_PREFIX = "Bearer "
AUTH_PREFIX = "Bearer "
)
// Config contains the initial Headscale configuration.
@ -75,6 +73,8 @@ type Config struct {
ACMEEmail string
DNSConfig *tailcfg.DNSConfig
UnixSocket string
}
type DERPConfig struct {
@ -233,8 +233,9 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// the server
p, _ := peer.FromContext(ctx)
if IsLocalhost(p.Addr.String()) {
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connected from localhost")
// TODO(kradalby): Figure out what @ means (socket wise) and if it can be exploited
if p.Addr.String() == "@" {
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connecting over socket")
return handler(ctx, req)
}
@ -326,14 +327,19 @@ func (h *Headscale) Serve() error {
defer cancel()
l, err := net.Listen("tcp", h.cfg.Addr)
socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
if err != nil {
panic(err)
}
networkListener, err := net.Listen("tcp", h.cfg.Addr)
if err != nil {
panic(err)
}
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
m := cmux.New(l)
m := cmux.New(networkListener)
// Match gRPC requests here
grpcListener := m.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
@ -344,16 +350,23 @@ func (h *Headscale) Serve() error {
grpcGatewayMux := runtime.NewServeMux()
grpcDialOptions := []grpc.DialOption{grpc.WithInsecure()}
_, port, err := net.SplitHostPort(h.cfg.Addr)
// Make the grpc-gateway connect to grpc over socket
grpcGatewayConn, err := grpc.Dial(
h.cfg.UnixSocket,
[]grpc.DialOption{
grpc.WithInsecure(),
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}),
}...,
)
if err != nil {
return err
}
// Connect to the gRPC server over localhost to skip
// the authentication.
err = apiV1.RegisterHeadscaleServiceHandlerFromEndpoint(ctx, grpcGatewayMux, LOCALHOST_V4+":"+port, grpcDialOptions)
err = apiV1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
if err != nil {
return err
}
@ -432,6 +445,7 @@ func (h *Headscale) Serve() error {
g := new(errgroup.Group)
g.Go(func() error { return grpcServer.Serve(socketListener) })
g.Go(func() error { return grpcServer.Serve(grpcListener) })
g.Go(func() error { return httpServer.Serve(httpListener) })
g.Go(func() error { return m.Serve() })

View file

@ -48,6 +48,8 @@ func LoadConfig(path string) error {
viper.SetDefault("dns_config", nil)
viper.SetDefault("unix_socket", "/var/run/headscale.sock")
err := viper.ReadInConfig()
if err != nil {
return fmt.Errorf("Fatal error reading config file: %s \n", err)
@ -242,6 +244,8 @@ func getHeadscaleConfig() headscale.Config {
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
UnixSocket: viper.GetString("unix_socket"),
}
}
@ -282,11 +286,11 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
// TODO(kradalby): Make configurable
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
grpcOptions := []grpc.DialOption{
// TODO(kradalby): Make configurable
grpc.WithBlock(),
}
@ -294,19 +298,24 @@ func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
// If the address is not set, we assume that we are on the server hosting headscale.
if address == "" {
log.Debug().Msgf("HEADSCALE_ADDRESS environment is not set, connecting to localhost.")
cfg := getHeadscaleConfig()
_, port, _ := net.SplitHostPort(cfg.Addr)
log.Debug().
Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_ADDRESS environment is not set, connecting to unix socket.")
address = "127.0.0.1" + ":" + port
address = cfg.UnixSocket
grpcOptions = append(grpcOptions, grpc.WithInsecure())
}
// If we are not connecting to a local server, require an API key for authentication
if !headscale.IsLocalhost(address) {
grpcOptions = append(
grpcOptions,
grpc.WithInsecure(),
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}),
)
} else {
// If we are not connecting to a local server, require an API key for authentication
apiKey := os.Getenv("HEADSCALE_API_KEY")
if apiKey == "" {
log.Fatal().Msgf("HEADSCALE_API_KEY environment variable needs to be set.")

View file

@ -64,3 +64,8 @@ dns_config:
magic_dns: true
base_domain: example.com
# Unix socket used for the CLI to connect without authentication
# Note: for local development, you probably want to change this to:
# unix_socket: ./headscale.sock
unix_socket: /var/run/headscale.sock

View file

@ -1,106 +0,0 @@
syntax = "proto3";
package headscale.v1;
option go_package = "github.com/juanfont/headscale/gen/go/v1";
import "google/protobuf/timestamp.proto";
import "google/api/annotations.proto";
enum RegisterMethod {
REGISTER_METHOD_UNSPECIFIED = 0;
REGISTER_METHOD_AUTH_KEY = 1;
REGISTER_METHOD_CLI = 2;
REGISTER_METHOD_OIDC = 3;
}
// message PreAuthKey {
// uint64 id = 1;
// string key = 2;
// uint32 namespace_id = 3;
// Namespace namespace = 4;
// bool reusable = 5;
// bool ephemeral = 6;
// bool used = 7;
//
// google.protobuf.Timestamp created_at = 8;
// google.protobuf.Timestamp expiration = 9;
// }
message GetMachineRequest {
uint64 machine_id = 1;
}
message GetMachineResponse {
uint64 id = 1;
string machine_key = 2;
string node_key = 3;
string disco_key = 4;
string ip_address = 5;
string name = 6;
uint32 namespace_id = 7;
bool registered = 8;
RegisterMethod register_method = 9;
uint32 auth_key_id = 10;
// PreAuthKey auth_key = 11;
google.protobuf.Timestamp last_seen = 12;
google.protobuf.Timestamp last_successful_update = 13;
google.protobuf.Timestamp expiry = 14;
// bytes host_info = 15;
// bytes endpoints = 16;
// bytes enabled_routes = 17;
// google.protobuf.Timestamp created_at = 18;
// google.protobuf.Timestamp updated_at = 19;
// google.protobuf.Timestamp deleted_at = 20;
}
message CreateNamespaceRequest {
string name = 1;
}
message CreateNamespaceResponse {
string name = 1;
}
message DeleteNamespaceRequest {
string name = 1;
}
message DeleteNamespaceResponse {
}
message ListNamespacesRequest {
}
message ListNamespacesResponse {
repeated string namespaces = 1;
}
service HeadscaleService {
rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) {
option(google.api.http) = {
get : "/api/v1/machine/{machine_id}"
};
}
rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) {
option(google.api.http) = {
post : "/api/v1/namespace"
body : "*"
};
}
rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) {
option(google.api.http) = {
delete : "/api/v1/namespace"
};
}
rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) {
option(google.api.http) = {
get : "/api/v1/namespace"
};
}
}

View file

@ -156,11 +156,3 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
func tailMapResponseToString(resp tailcfg.MapResponse) string {
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
}
func IsLocalhost(host string) bool {
if strings.Contains(host, LOCALHOST_V4) || strings.Contains(host, LOCALHOST_V6) {
return true
}
return false
}