Make TLS setup work automatically

This commit injects the per-test-generated tls certs into the tailscale
container and makes sure all can ping all. It does not test any of the
DERP isolation yet.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2022-11-06 20:22:21 +01:00 committed by Juan Font
parent 89ff5c83d2
commit 9bc6ac0f35
4 changed files with 363 additions and 96 deletions

View file

@ -1,52 +1,83 @@
package hsic package hsic
import ( import (
"archive/tar"
"bytes" "bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"math/big"
"net"
"net/http" "net/http"
"path/filepath" "time"
"github.com/juanfont/headscale" "github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
) )
const ( const (
hsicHashLength = 6 hsicHashLength = 6
dockerContextPath = "../." dockerContextPath = "../."
aclPolicyPath = "/etc/headscale/acl.hujson" aclPolicyPath = "/etc/headscale/acl.hujson"
tlsCertPath = "/etc/headscale/tls.cert"
tlsKeyPath = "/etc/headscale/tls.key"
headscaleDefaultPort = 8080
) )
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok") var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
type HeadscaleInContainer struct { type HeadscaleInContainer struct {
hostname string hostname string
port int
pool *dockertest.Pool pool *dockertest.Pool
container *dockertest.Resource container *dockertest.Resource
network *dockertest.Network network *dockertest.Network
// optional config // optional config
port int
aclPolicy *headscale.ACLPolicy aclPolicy *headscale.ACLPolicy
env []string env []string
tlsCert []byte
tlsKey []byte
} }
type Option = func(c *HeadscaleInContainer) type Option = func(c *HeadscaleInContainer)
func WithACLPolicy(acl *headscale.ACLPolicy) Option { func WithACLPolicy(acl *headscale.ACLPolicy) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
// TODO(kradalby): Move somewhere appropriate
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
hsic.aclPolicy = acl hsic.aclPolicy = acl
} }
} }
func WithTLS() Option {
return func(hsic *HeadscaleInContainer) {
cert, key, err := createCertificate()
if err != nil {
log.Fatalf("failed to create certificates for headscale test: %s", err)
}
// TODO(kradalby): Move somewhere appropriate
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_CERT_PATH=%s", tlsCertPath))
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_KEY_PATH=%s", tlsKeyPath))
hsic.env = append(hsic.env, "HEADSCALE_TLS_CLIENT_AUTH_MODE=disabled")
hsic.tlsCert = cert
hsic.tlsKey = key
}
}
func WithConfigEnv(configEnv map[string]string) Option { func WithConfigEnv(configEnv map[string]string) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
env := []string{} env := []string{}
@ -59,9 +90,14 @@ func WithConfigEnv(configEnv map[string]string) Option {
} }
} }
func WithPort(port int) Option {
return func(hsic *HeadscaleInContainer) {
hsic.port = port
}
}
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
port int,
network *dockertest.Network, network *dockertest.Network,
opts ...Option, opts ...Option,
) (*HeadscaleInContainer, error) { ) (*HeadscaleInContainer, error) {
@ -71,11 +107,10 @@ func New(
} }
hostname := fmt.Sprintf("hs-%s", hash) hostname := fmt.Sprintf("hs-%s", hash)
portProto := fmt.Sprintf("%d/tcp", port)
hsic := &HeadscaleInContainer{ hsic := &HeadscaleInContainer{
hostname: hostname, hostname: hostname,
port: port, port: headscaleDefaultPort,
pool: pool, pool: pool,
network: network, network: network,
@ -85,9 +120,7 @@ func New(
opt(hsic) opt(hsic)
} }
if hsic.aclPolicy != nil { portProto := fmt.Sprintf("%d/tcp", hsic.port)
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
}
headscaleBuildOptions := &dockertest.BuildOptions{ headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.debug", Dockerfile: "Dockerfile.debug",
@ -144,9 +177,25 @@ func New(
} }
} }
if hsic.hasTLS() {
err = hsic.WriteFile(tlsCertPath, hsic.tlsCert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
err = hsic.WriteFile(tlsKeyPath, hsic.tlsKey)
if err != nil {
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
}
}
return hsic, nil return hsic, nil
} }
func (t *HeadscaleInContainer) hasTLS() bool {
return len(t.tlsCert) != 0 && len(t.tlsKey) != 0
}
func (t *HeadscaleInContainer) Shutdown() error { func (t *HeadscaleInContainer) Shutdown() error {
return t.pool.Purge(t.container) return t.pool.Purge(t.container)
} }
@ -183,11 +232,7 @@ func (t *HeadscaleInContainer) GetPort() string {
} }
func (t *HeadscaleInContainer) GetHealthEndpoint() string { func (t *HeadscaleInContainer) GetHealthEndpoint() string {
hostEndpoint := fmt.Sprintf("%s:%d", return fmt.Sprintf("%s/health", t.GetEndpoint())
t.GetIP(),
t.port)
return fmt.Sprintf("http://%s/health", hostEndpoint)
} }
func (t *HeadscaleInContainer) GetEndpoint() string { func (t *HeadscaleInContainer) GetEndpoint() string {
@ -195,17 +240,39 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
t.GetIP(), t.GetIP(),
t.port) t.port)
if t.hasTLS() {
return fmt.Sprintf("https://%s", hostEndpoint)
}
return fmt.Sprintf("http://%s", hostEndpoint) return fmt.Sprintf("http://%s", hostEndpoint)
} }
func (t *HeadscaleInContainer) GetCert() []byte {
return t.tlsCert
}
func (t *HeadscaleInContainer) GetHostname() string {
return t.hostname
}
func (t *HeadscaleInContainer) WaitForReady() error { func (t *HeadscaleInContainer) WaitForReady() error {
url := t.GetHealthEndpoint() url := t.GetHealthEndpoint()
log.Printf("waiting for headscale to be ready at %s", url) log.Printf("waiting for headscale to be ready at %s", url)
client := &http.Client{}
if t.hasTLS() {
insecureTransport := http.DefaultTransport.(*http.Transport).Clone()
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client = &http.Client{Transport: insecureTransport}
}
return t.pool.Retry(func() error { return t.pool.Retry(func() error {
resp, err := http.Get(url) //nolint resp, err := client.Get(url) //nolint
if err != nil { if err != nil {
log.Printf("ready err: %s", err)
return fmt.Errorf("headscale is not ready: %w", err) return fmt.Errorf("headscale is not ready: %w", err)
} }
@ -292,55 +359,96 @@ func (t *HeadscaleInContainer) ListMachinesInNamespace(
} }
func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error { func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error {
dirPath, fileName := filepath.Split(path) return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
file := bytes.NewReader(data)
buf := bytes.NewBuffer([]byte{})
tarWriter := tar.NewWriter(buf)
header := &tar.Header{
Name: fileName,
Size: file.Size(),
// Mode: int64(stat.Mode()),
// ModTime: stat.ModTime(),
} }
err := tarWriter.WriteHeader(header) func createCertificate() ([]byte, []byte, error) {
if err != nil { // From:
return fmt.Errorf("failed write file header to tar: %w", err) // https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
}
_, err = io.Copy(tarWriter, file) ca := &x509.Certificate{
if err != nil { SerialNumber: big.NewInt(2019),
return fmt.Errorf("failed to copy file to tar: %w", err) Subject: pkix.Name{
} Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
err = tarWriter.Close() Locality: []string{"Leiden"},
if err != nil {
return fmt.Errorf("failed to close tar: %w", err)
}
log.Printf("tar: %s", buf.String())
// Ensure the directory is present inside the container
_, err = t.Execute([]string{"mkdir", "-p", dirPath})
if err != nil {
return fmt.Errorf("failed to ensure directory: %w", err)
}
err = t.pool.Client.UploadToContainer(
t.container.Container.ID,
docker.UploadToContainerOptions{
NoOverwriteDirNonDir: false,
Path: dirPath,
InputStream: bytes.NewReader(buf.Bytes()),
}, },
NotBefore: time.Now(),
NotAfter: time.Now().Add(30 * time.Minute),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
// caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
// if err != nil {
// return nil, err
// }
cert := &x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().Add(30 * time.Minute),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
cert,
ca,
&certPrivKey.PublicKey,
caPrivKey,
) )
if err != nil { if err != nil {
return err return nil, nil, err
} }
return nil certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, err
}
certPrivKeyPEM := new(bytes.Buffer)
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
return nil, nil, err
}
// serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
// if err != nil {
// return nil, err
// }
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
} }

View file

@ -0,0 +1,77 @@
package integrationutil
import (
"archive/tar"
"bytes"
"fmt"
"io"
"log"
"path/filepath"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
func WriteFileToContainer(
pool *dockertest.Pool,
container *dockertest.Resource,
path string,
data []byte,
) error {
dirPath, fileName := filepath.Split(path)
file := bytes.NewReader(data)
buf := bytes.NewBuffer([]byte{})
tarWriter := tar.NewWriter(buf)
header := &tar.Header{
Name: fileName,
Size: file.Size(),
// Mode: int64(stat.Mode()),
// ModTime: stat.ModTime(),
}
err := tarWriter.WriteHeader(header)
if err != nil {
return fmt.Errorf("failed write file header to tar: %w", err)
}
_, err = io.Copy(tarWriter, file)
if err != nil {
return fmt.Errorf("failed to copy file to tar: %w", err)
}
err = tarWriter.Close()
if err != nil {
return fmt.Errorf("failed to close tar: %w", err)
}
log.Printf("tar: %s", buf.String())
// Ensure the directory is present inside the container
_, _, err = dockertestutil.ExecuteCommand(
container,
[]string{"mkdir", "-p", dirPath},
[]string{},
)
if err != nil {
return fmt.Errorf("failed to ensure directory: %w", err)
}
err = pool.Client.UploadToContainer(
container.Container.ID,
docker.UploadToContainerOptions{
NoOverwriteDirNonDir: false,
Path: dirPath,
InputStream: bytes.NewReader(buf.Bytes()),
},
)
if err != nil {
return err
}
return nil
}

View file

@ -150,20 +150,8 @@ func (s *Scenario) Namespaces() []string {
// Note: These functions assume that there is a _single_ headscale instance for now // Note: These functions assume that there is a _single_ headscale instance for now
// TODO(kradalby): make port and headscale configurable, multiple instances support? // TODO(kradalby): make port and headscale configurable, multiple instances support?
func (s *Scenario) StartHeadscale() error { func (s *Scenario) StartHeadscale(opts ...hsic.Option) error {
headscale, err := hsic.New(s.pool, headscalePort, s.network, headscale, err := hsic.New(s.pool, s.network, opts...)
hsic.WithACLPolicy(
&headscale.ACLPolicy{
ACLs: []headscale.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
},
),
)
if err != nil { if err != nil {
return fmt.Errorf("failed to create headscale container: %w", err) return fmt.Errorf("failed to create headscale container: %w", err)
} }
@ -228,10 +216,22 @@ func (s *Scenario) CreateTailscaleNodesInNamespace(
defer namespace.createWaitGroup.Done() defer namespace.createWaitGroup.Done()
// TODO(kradalby): error handle this // TODO(kradalby): error handle this
tsClient, err := tsic.New(s.pool, version, s.network) tsClient, err := tsic.New(
s.pool,
version,
s.network,
tsic.WithHeadscaleTLS(s.Headscale().GetCert()),
tsic.WithHeadscaleName(s.Headscale().GetHostname()),
)
if err != nil { if err != nil {
// return fmt.Errorf("failed to add tailscale node: %w", err) // return fmt.Errorf("failed to add tailscale node: %w", err)
log.Printf("failed to add tailscale node: %s", err) log.Printf("failed to create tailscale node: %s", err)
}
err = tsClient.WaitForReady()
if err != nil {
// return fmt.Errorf("failed to add tailscale node: %w", err)
log.Printf("failed to wait for tailscaled: %s", err)
} }
namespace.Clients[tsClient.Hostname()] = tsClient namespace.Clients[tsClient.Hostname()] = tsClient
@ -306,8 +306,8 @@ func (s *Scenario) WaitForTailscaleSync() error {
// CreateHeadscaleEnv is a conventient method returning a set up Headcale // CreateHeadscaleEnv is a conventient method returning a set up Headcale
// test environment with nodes of all versions, joined to the server with X // test environment with nodes of all versions, joined to the server with X
// namespaces. // namespaces.
func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int) error { func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int, opts ...hsic.Option) error {
err := s.StartHeadscale() err := s.StartHeadscale(opts...)
if err != nil { if err != nil {
return err return err
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/juanfont/headscale" "github.com/juanfont/headscale"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
@ -20,6 +21,7 @@ import (
const ( const (
tsicHashLength = 6 tsicHashLength = 6
dockerContextPath = "../." dockerContextPath = "../."
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
) )
var ( var (
@ -41,12 +43,51 @@ type TailscaleInContainer struct {
// "cache" // "cache"
ips []netip.Addr ips []netip.Addr
fqdn string fqdn string
// optional config
headscaleCert []byte
headscaleHostname string
}
type Option = func(c *TailscaleInContainer)
func WithHeadscaleTLS(cert []byte) Option {
return func(tsic *TailscaleInContainer) {
tsic.headscaleCert = cert
}
}
func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *TailscaleInContainer) {
if network != nil {
tsic.network = network
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
fmt.Sprintf("%s-network", tsic.hostname),
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network
}
}
func WithHeadscaleName(hsName string) Option {
return func(tsic *TailscaleInContainer) {
tsic.headscaleHostname = hsName
}
} }
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
version string, version string,
network *dockertest.Network, network *dockertest.Network,
opts ...Option,
) (*TailscaleInContainer, error) { ) (*TailscaleInContainer, error) {
hash, err := headscale.GenerateRandomStringDNSSafe(tsicHashLength) hash, err := headscale.GenerateRandomStringDNSSafe(tsicHashLength)
if err != nil { if err != nil {
@ -55,20 +96,38 @@ func New(
hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
// TODO(kradalby): figure out why we need to "refresh" the network here. tsic := &TailscaleInContainer{
// network, err = dockertestutil.GetFirstOrCreateNetwork(pool, network.Network.Name) version: version,
// if err != nil { hostname: hostname,
// return nil, err
// } pool: pool,
network: network,
}
for _, opt := range opts {
opt(tsic)
}
tailscaleOptions := &dockertest.RunOptions{ tailscaleOptions := &dockertest.RunOptions{
Name: hostname, Name: hostname,
Networks: []*dockertest.Network{network}, Networks: []*dockertest.Network{network},
Cmd: []string{ // Cmd: []string{
"tailscaled", "--tun=tsdev", // "tailscaled", "--tun=tsdev",
// },
Entrypoint: []string{
"/bin/bash",
"-c",
"/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev",
}, },
} }
if tsic.headscaleHostname != "" {
tailscaleOptions.ExtraHosts = []string{
"host.docker.internal:host-gateway",
fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
}
}
// dockertest isnt very good at handling containers that has already // dockertest isnt very good at handling containers that has already
// been created, this is an attempt to make sure this container isnt // been created, this is an attempt to make sure this container isnt
// present. // present.
@ -89,14 +148,20 @@ func New(
} }
log.Printf("Created %s container\n", hostname) log.Printf("Created %s container\n", hostname)
return &TailscaleInContainer{ tsic.container = container
version: version,
hostname: hostname,
pool: pool, if tsic.hasTLS() {
container: container, err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
network: network, if err != nil {
}, nil return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
return tsic, nil
}
func (t *TailscaleInContainer) hasTLS() bool {
return len(t.headscaleCert) != 0
} }
func (t *TailscaleInContainer) Shutdown() error { func (t *TailscaleInContainer) Shutdown() error {
@ -111,6 +176,19 @@ func (t *TailscaleInContainer) Version() string {
return t.version return t.version
} }
func (t *TailscaleInContainer) WaitForReady() error {
return t.pool.Retry(func() error {
// If tailscaled has not started yet, this will return a non-zero
// status code
_, err := t.Execute([]string{"tailscale", "status"})
if err != nil {
return err
}
return nil
})
}
func (t *TailscaleInContainer) Execute( func (t *TailscaleInContainer) Execute(
command []string, command []string,
) (string, string, error) { ) (string, string, error) {
@ -318,6 +396,10 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string) error {
}) })
} }
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
}
func createTailscaleBuildOptions(version string) *dockertest.BuildOptions { func createTailscaleBuildOptions(version string) *dockertest.BuildOptions {
var tailscaleBuildOptions *dockertest.BuildOptions var tailscaleBuildOptions *dockertest.BuildOptions
switch version { switch version {