feat: add WithGrpcTLS shim

This commit is contained in:
Kedas 2024-05-25 05:35:00 +00:00
parent 50a7315226
commit 9163edcf50
3 changed files with 43 additions and 2 deletions

View file

@ -43,6 +43,7 @@ func TestDERPServerScenario(t *testing.T) {
hsic.WithExtraPorts([]string{"3478/udp"}), hsic.WithExtraPorts([]string{"3478/udp"}),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithGrpcTLS(),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)

View file

@ -41,6 +41,7 @@ func TestPingAllByIP(t *testing.T) {
hsic.WithTestName("pingallbyip"), hsic.WithTestName("pingallbyip"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithGrpcTLS(),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
) )
@ -836,6 +837,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
hsic.WithTestName("pingallbyipmany"), hsic.WithTestName("pingallbyipmany"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(), hsic.WithTLS(),
hsic.WithGrpcTLS(),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)

View file

@ -40,6 +40,8 @@ const (
aclPolicyPath = "/etc/headscale/acl.hujson" aclPolicyPath = "/etc/headscale/acl.hujson"
tlsCertPath = "/etc/headscale/tls.cert" tlsCertPath = "/etc/headscale/tls.cert"
tlsKeyPath = "/etc/headscale/tls.key" tlsKeyPath = "/etc/headscale/tls.key"
grpcTlsCertPath = "/etc/headscale/grpc_tls.cert"
grpcTlsKeyPath = "/etc/headscale/grpc_tls.key"
headscaleDefaultPort = 8080 headscaleDefaultPort = 8080
) )
@ -69,6 +71,8 @@ type HeadscaleInContainer struct {
env map[string]string env map[string]string
tlsCert []byte tlsCert []byte
tlsKey []byte tlsKey []byte
grpcTlsCert []byte
grpcTlsKey []byte
filesInContainer []fileInContainer filesInContainer []fileInContainer
postgres bool postgres bool
} }
@ -97,14 +101,27 @@ func WithTLS() Option {
} }
// TODO(kradalby): Move somewhere appropriate // TODO(kradalby): Move somewhere appropriate
hsic.env["HEADSCALE_GRPC_TLS_CERT_PATH"] = tlsCertPath hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
hsic.env["HEADSCALE_GRPC_TLS_KEY_PATH"] = tlsKeyPath hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
hsic.tlsCert = cert hsic.tlsCert = cert
hsic.tlsKey = key hsic.tlsKey = key
} }
} }
// WithGrpcTLS creates gRPC certificates and enables them.
func WithGrpcTLS() Option {
return func(hsic *HeadscaleInContainer) {
cert, key, err := createCertificate(hsic.hostname)
if err != nil {
log.Fatalf("failed to create grpc certificates for headscale test: %s", err)
}
hsic.grpcTlsCert = cert
hsic.grpcTlsKey = key
}
}
// WithConfigEnv takes a map of environment variables that // WithConfigEnv takes a map of environment variables that
// can be used to override Headscale configuration. // can be used to override Headscale configuration.
func WithConfigEnv(configEnv map[string]string) Option { func WithConfigEnv(configEnv map[string]string) Option {
@ -253,6 +270,11 @@ func New(
hsic.env["HEADSCALE_SERVER_URL"] = serverURL.String() hsic.env["HEADSCALE_SERVER_URL"] = serverURL.String()
} }
if hsic.hasGrpcTLS() {
hsic.env["HEADSCALE_GRPC_TLS_CERT_PATH"] = grpcTlsCertPath
hsic.env["HEADSCALE_GRPC_TLS_KEY_PATH"] = grpcTlsKeyPath
}
headscaleBuildOptions := &dockertest.BuildOptions{ headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.debug", Dockerfile: "Dockerfile.debug",
ContextDir: dockerContextPath, ContextDir: dockerContextPath,
@ -374,6 +396,18 @@ func New(
} }
} }
if hsic.hasGrpcTLS() {
err = hsic.WriteFile(grpcTlsCertPath, hsic.grpcTlsCert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
err = hsic.WriteFile(grpcTlsKeyPath, hsic.grpcTlsKey)
if err != nil {
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
}
}
for _, f := range hsic.filesInContainer { for _, f := range hsic.filesInContainer {
if err := hsic.WriteFile(f.path, f.contents); err != nil { if err := hsic.WriteFile(f.path, f.contents); err != nil {
return nil, fmt.Errorf("failed to write %q: %w", f.path, err) return nil, fmt.Errorf("failed to write %q: %w", f.path, err)
@ -391,6 +425,10 @@ func (t *HeadscaleInContainer) hasTLS() bool {
return len(t.tlsCert) != 0 && len(t.tlsKey) != 0 return len(t.tlsCert) != 0 && len(t.tlsKey) != 0
} }
func (t *HeadscaleInContainer) hasGrpcTLS() bool {
return len(t.grpcTlsCert) != 0 && len(t.grpcTlsKey) != 0
}
// Shutdown stops and cleans up the Headscale container. // Shutdown stops and cleans up the Headscale container.
func (t *HeadscaleInContainer) Shutdown() error { func (t *HeadscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control") err := t.SaveLog("/tmp/control")