can the policy be typed at parsetime?

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-04 17:39:40 +02:00
parent 2c1ad6d11a
commit 9f6c8ab62e
No known key found for this signature in database
2 changed files with 460 additions and 0 deletions

365
hscontrol/policyv2/types.go Normal file
View file

@ -0,0 +1,365 @@
package policyv2
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// Username is a string that represents a username, it must contain an @.
type Username string
func (u Username) Valid() bool {
return strings.Contains(string(u), "@")
}
func (u Username) UnmarshalJSON(b []byte) error {
u = Username(strings.Trim(string(b), `"`))
if !u.Valid() {
return fmt.Errorf("invalid username %q", u)
}
return nil
}
// Group is a special string which is always prefixed with `group:`
type Group string
func (g Group) Valid() bool {
return strings.HasPrefix(string(g), "group:")
}
func (g Group) UnmarshalJSON(b []byte) error {
g = Group(strings.Trim(string(b), `"`))
if !g.Valid() {
return fmt.Errorf("invalid group %q", g)
}
return nil
}
// Tag is a special string which is always prefixed with `tag:`
type Tag string
func (t Tag) Valid() bool {
return strings.HasPrefix(string(t), "tag:")
}
func (t Tag) UnmarshalJSON(b []byte) error {
t = Tag(strings.Trim(string(b), `"`))
if !t.Valid() {
return fmt.Errorf("invalid tag %q", t)
}
return nil
}
// Host is a string that represents a hostname.
type Host string
func (h Host) Valid() bool {
return true
}
func (h Host) UnmarshalJSON(b []byte) error {
h = Host(strings.Trim(string(b), `"`))
if !h.Valid() {
return fmt.Errorf("invalid host %q", h)
}
return nil
}
type Addr netip.Addr
func (a Addr) Valid() bool {
return netip.Addr(a).IsValid()
}
func (a Addr) UnmarshalJSON(b []byte) error {
a = Addr(netip.Addr{})
if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil {
return err
}
if !a.Valid() {
return fmt.Errorf("invalid address %v", a)
}
return nil
}
type Prefix netip.Prefix
func (p Prefix) Valid() bool {
return netip.Prefix(p).IsValid()
}
func (p Prefix) UnmarshalJSON(b []byte) error {
p = Prefix(netip.Prefix{})
if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil {
return err
}
if !p.Valid() {
return fmt.Errorf("invalid prefix %v", p)
}
return nil
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
type AutoGroup string
func (ag AutoGroup) Valid() bool {
return strings.HasPrefix(string(ag), "autogroup:")
}
func (ag AutoGroup) UnmarshalJSON(b []byte) error {
ag = AutoGroup(strings.Trim(string(b), `"`))
if !ag.Valid() {
return fmt.Errorf("invalid autogroup %q", ag)
}
return nil
}
type Alias interface {
Valid() bool
UnmarshalJSON([]byte) error
}
type AliasWithPorts struct {
Alias
Ports []tailcfg.PortRange
}
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch vs := v.(type) {
case string:
var portsPart string
var err error
if strings.Contains(vs, ":") {
vs, portsPart, err = splitDestination(vs)
if err != nil {
return err
}
ports, err := parsePorts(portsPart)
if err != nil {
return err
}
ve.Ports = ports
}
ve.Alias = parseAlias(vs)
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
func parseAlias(vs string) Alias {
// case netip.Addr:
// ve.Alias = Addr(val)
// case netip.Prefix:
// ve.Alias = Prefix(val)
if addr, err := netip.ParseAddr(vs); err == nil {
return Addr(addr)
}
if prefix, err := netip.ParsePrefix(vs); err == nil {
return Prefix(prefix)
}
switch {
case strings.Contains(vs, "@"):
return Username(vs)
case strings.HasPrefix(vs, "group:"):
return Group(vs)
case strings.HasPrefix(vs, "tag:"):
return Tag(vs)
case strings.HasPrefix(vs, "autogroup:"):
return AutoGroup(vs)
}
return Host(vs)
}
// AliasEnc is used to deserialize a Alias.
type AliasEnc struct{ Alias }
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch val := v.(type) {
case string:
ve.Alias = parseAlias(val)
default:
return fmt.Errorf("type %T not supported", val)
}
return nil
}
type Aliases []Alias
func (a *Aliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
// UserEntity is an interface that represents something that can
// return a list of users:
// - Username
// - Group
// - AutoGroup
type UserEntity interface {
Users() []Username
UnmarshalJSON([]byte) error
}
// Groups are a map of Group to a list of Username.
type Groups map[Group][]Username
// Hosts are alias for IP addresses or subnets.
type Hosts map[Host]netip.Prefix
// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
type TagOwners map[Tag][]UserEntity
type AutoApprovers struct {
Routes map[string][]string `json:"routes"`
ExitNode []string `json:"exitNode"`
}
type ACL struct {
Action string `json:"action"`
Protocol string `json:"proto"`
Sources Aliases `json:"src"`
Destinations []AliasWithPorts `json:"dst"`
}
// ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct {
Groups Groups `json:"groups"`
// Hosts Hosts `json:"hosts"`
TagOwners TagOwners `json:"tagOwners"`
ACLs []ACL `json:"acls"`
AutoApprovers AutoApprovers `json:"autoApprovers"`
// SSHs []SSH `json:"ssh"`
}
const (
expectedTokenItems = 2
)
// TODO(kradalby): copy tests from parseDestination in policy
func splitDestination(dest string) (string, string, error) {
var tokens []string
// Check if there is a IPv4/6:Port combination, IPv6 has more than
// three ":".
tokens = strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
port := tokens[len(tokens)-1]
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
filteredMaybeIPv6Str := maybeIPv6Str
if strings.Contains(maybeIPv6Str, "/") {
networkParts := strings.Split(maybeIPv6Str, "/")
filteredMaybeIPv6Str = networkParts[0]
}
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
return "", "", fmt.Errorf(
"failed to split destination: %v",
tokens,
)
} else {
tokens = []string{maybeIPv6Str, port}
}
}
var alias string
// We can have here stuff like:
// git-server:*
// 192.168.1.0/24:22
// fd7a:115c:a1e0::2:22
// fd7a:115c:a1e0::2/128:22
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
return alias, tokens[len(tokens)-1], nil
}
// TODO(kradalby): write/copy tests from expandPorts in policy
func parsePorts(portsStr string) ([]tailcfg.PortRange, error) {
if portsStr == "*" {
return []tailcfg.PortRange{
tailcfg.PortRangeAny,
}, nil
}
var ports []tailcfg.PortRange
for _, portStr := range strings.Split(portsStr, ",") {
log.Trace().Msgf("parsing portstring: %s", portStr)
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(port),
Last: uint16(port),
})
case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
default:
return nil, errors.New("invalid ports")
}
}
return ports, nil
}

View file

@ -0,0 +1,95 @@
package policyv2
import (
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/tailscale/hujson"
"tailscale.com/tailcfg"
)
func TestUnmarshalPolicy(t *testing.T) {
tests := []struct {
name string
input string
want *ACLPolicy
wantErr error
}{
{
name: "empty",
input: "{}",
want: &ACLPolicy{},
},
{
name: "basic-types",
input: `
{
"groups": {
"group:example": [
"testuser@headscale.net",
],
},
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`,
want: &ACLPolicy{
Groups: Groups{
Group("group:example"): []Username{"testuser@headscale.net"},
},
ACLs: []ACL{
{
Action: "accept",
Sources: Aliases{
Group("group:example"),
},
Destinations: []AliasWithPorts{
{
Alias: Host("host-1"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var policy ACLPolicy
ast, err := hujson.Parse([]byte(tt.input))
if err != nil {
t.Fatalf("parsing hujson: %s", err)
}
ast.Standardize()
acl := ast.Pack()
if err := json.Unmarshal(acl, &policy); err != nil {
// TODO: check error type
t.Fatalf("unmarshaling json: %s", err)
}
if diff := cmp.Diff(tt.want, &policy); diff != "" {
t.Fatalf("unexpected policy (-want +got):\n%s", diff)
}
})
}
}