mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
can the policy be typed at parsetime?
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
2c1ad6d11a
commit
9f6c8ab62e
2 changed files with 460 additions and 0 deletions
365
hscontrol/policyv2/types.go
Normal file
365
hscontrol/policyv2/types.go
Normal file
|
@ -0,0 +1,365 @@
|
||||||
|
package policyv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Username is a string that represents a username, it must contain an @.
|
||||||
|
type Username string
|
||||||
|
|
||||||
|
func (u Username) Valid() bool {
|
||||||
|
return strings.Contains(string(u), "@")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u Username) UnmarshalJSON(b []byte) error {
|
||||||
|
u = Username(strings.Trim(string(b), `"`))
|
||||||
|
if !u.Valid() {
|
||||||
|
return fmt.Errorf("invalid username %q", u)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group is a special string which is always prefixed with `group:`
|
||||||
|
type Group string
|
||||||
|
|
||||||
|
func (g Group) Valid() bool {
|
||||||
|
return strings.HasPrefix(string(g), "group:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g Group) UnmarshalJSON(b []byte) error {
|
||||||
|
g = Group(strings.Trim(string(b), `"`))
|
||||||
|
if !g.Valid() {
|
||||||
|
return fmt.Errorf("invalid group %q", g)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tag is a special string which is always prefixed with `tag:`
|
||||||
|
type Tag string
|
||||||
|
|
||||||
|
func (t Tag) Valid() bool {
|
||||||
|
return strings.HasPrefix(string(t), "tag:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Tag) UnmarshalJSON(b []byte) error {
|
||||||
|
t = Tag(strings.Trim(string(b), `"`))
|
||||||
|
if !t.Valid() {
|
||||||
|
return fmt.Errorf("invalid tag %q", t)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host is a string that represents a hostname.
|
||||||
|
type Host string
|
||||||
|
|
||||||
|
func (h Host) Valid() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Host) UnmarshalJSON(b []byte) error {
|
||||||
|
h = Host(strings.Trim(string(b), `"`))
|
||||||
|
if !h.Valid() {
|
||||||
|
return fmt.Errorf("invalid host %q", h)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Addr netip.Addr
|
||||||
|
|
||||||
|
func (a Addr) Valid() bool {
|
||||||
|
return netip.Addr(a).IsValid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Addr) UnmarshalJSON(b []byte) error {
|
||||||
|
a = Addr(netip.Addr{})
|
||||||
|
if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !a.Valid() {
|
||||||
|
return fmt.Errorf("invalid address %v", a)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Prefix netip.Prefix
|
||||||
|
|
||||||
|
func (p Prefix) Valid() bool {
|
||||||
|
return netip.Prefix(p).IsValid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Prefix) UnmarshalJSON(b []byte) error {
|
||||||
|
p = Prefix(netip.Prefix{})
|
||||||
|
if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !p.Valid() {
|
||||||
|
return fmt.Errorf("invalid prefix %v", p)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoGroup is a special string which is always prefixed with `autogroup:`
|
||||||
|
type AutoGroup string
|
||||||
|
|
||||||
|
func (ag AutoGroup) Valid() bool {
|
||||||
|
return strings.HasPrefix(string(ag), "autogroup:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ag AutoGroup) UnmarshalJSON(b []byte) error {
|
||||||
|
ag = AutoGroup(strings.Trim(string(b), `"`))
|
||||||
|
if !ag.Valid() {
|
||||||
|
return fmt.Errorf("invalid autogroup %q", ag)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Alias interface {
|
||||||
|
Valid() bool
|
||||||
|
UnmarshalJSON([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliasWithPorts struct {
|
||||||
|
Alias
|
||||||
|
Ports []tailcfg.PortRange
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
||||||
|
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(b))
|
||||||
|
var v any
|
||||||
|
if err := dec.Decode(&v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch vs := v.(type) {
|
||||||
|
case string:
|
||||||
|
var portsPart string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if strings.Contains(vs, ":") {
|
||||||
|
vs, portsPart, err = splitDestination(vs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ports, err := parsePorts(portsPart)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ve.Ports = ports
|
||||||
|
}
|
||||||
|
|
||||||
|
ve.Alias = parseAlias(vs)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("type %T not supported", vs)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAlias(vs string) Alias {
|
||||||
|
// case netip.Addr:
|
||||||
|
// ve.Alias = Addr(val)
|
||||||
|
// case netip.Prefix:
|
||||||
|
// ve.Alias = Prefix(val)
|
||||||
|
if addr, err := netip.ParseAddr(vs); err == nil {
|
||||||
|
return Addr(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix, err := netip.ParsePrefix(vs); err == nil {
|
||||||
|
return Prefix(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(vs, "@"):
|
||||||
|
return Username(vs)
|
||||||
|
case strings.HasPrefix(vs, "group:"):
|
||||||
|
return Group(vs)
|
||||||
|
case strings.HasPrefix(vs, "tag:"):
|
||||||
|
return Tag(vs)
|
||||||
|
case strings.HasPrefix(vs, "autogroup:"):
|
||||||
|
return AutoGroup(vs)
|
||||||
|
}
|
||||||
|
return Host(vs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AliasEnc is used to deserialize a Alias.
|
||||||
|
type AliasEnc struct{ Alias }
|
||||||
|
|
||||||
|
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
|
||||||
|
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(b))
|
||||||
|
var v any
|
||||||
|
if err := dec.Decode(&v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
ve.Alias = parseAlias(val)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("type %T not supported", val)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Aliases []Alias
|
||||||
|
|
||||||
|
func (a *Aliases) UnmarshalJSON(b []byte) error {
|
||||||
|
var aliases []AliasEnc
|
||||||
|
err := json.Unmarshal(b, &aliases)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*a = make([]Alias, len(aliases))
|
||||||
|
for i, alias := range aliases {
|
||||||
|
(*a)[i] = alias.Alias
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserEntity is an interface that represents something that can
|
||||||
|
// return a list of users:
|
||||||
|
// - Username
|
||||||
|
// - Group
|
||||||
|
// - AutoGroup
|
||||||
|
type UserEntity interface {
|
||||||
|
Users() []Username
|
||||||
|
UnmarshalJSON([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Groups are a map of Group to a list of Username.
|
||||||
|
type Groups map[Group][]Username
|
||||||
|
|
||||||
|
// Hosts are alias for IP addresses or subnets.
|
||||||
|
type Hosts map[Host]netip.Prefix
|
||||||
|
|
||||||
|
// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
|
||||||
|
type TagOwners map[Tag][]UserEntity
|
||||||
|
|
||||||
|
type AutoApprovers struct {
|
||||||
|
Routes map[string][]string `json:"routes"`
|
||||||
|
ExitNode []string `json:"exitNode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACL struct {
|
||||||
|
Action string `json:"action"`
|
||||||
|
Protocol string `json:"proto"`
|
||||||
|
Sources Aliases `json:"src"`
|
||||||
|
Destinations []AliasWithPorts `json:"dst"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ACLPolicy represents a Tailscale ACL Policy.
|
||||||
|
type ACLPolicy struct {
|
||||||
|
Groups Groups `json:"groups"`
|
||||||
|
// Hosts Hosts `json:"hosts"`
|
||||||
|
TagOwners TagOwners `json:"tagOwners"`
|
||||||
|
ACLs []ACL `json:"acls"`
|
||||||
|
AutoApprovers AutoApprovers `json:"autoApprovers"`
|
||||||
|
// SSHs []SSH `json:"ssh"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
expectedTokenItems = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(kradalby): copy tests from parseDestination in policy
|
||||||
|
func splitDestination(dest string) (string, string, error) {
|
||||||
|
var tokens []string
|
||||||
|
|
||||||
|
// Check if there is a IPv4/6:Port combination, IPv6 has more than
|
||||||
|
// three ":".
|
||||||
|
tokens = strings.Split(dest, ":")
|
||||||
|
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
|
||||||
|
port := tokens[len(tokens)-1]
|
||||||
|
|
||||||
|
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
|
||||||
|
|
||||||
|
filteredMaybeIPv6Str := maybeIPv6Str
|
||||||
|
if strings.Contains(maybeIPv6Str, "/") {
|
||||||
|
networkParts := strings.Split(maybeIPv6Str, "/")
|
||||||
|
filteredMaybeIPv6Str = networkParts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
|
||||||
|
return "", "", fmt.Errorf(
|
||||||
|
"failed to split destination: %v",
|
||||||
|
tokens,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
tokens = []string{maybeIPv6Str, port}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var alias string
|
||||||
|
// We can have here stuff like:
|
||||||
|
// git-server:*
|
||||||
|
// 192.168.1.0/24:22
|
||||||
|
// fd7a:115c:a1e0::2:22
|
||||||
|
// fd7a:115c:a1e0::2/128:22
|
||||||
|
// tag:montreal-webserver:80,443
|
||||||
|
// tag:api-server:443
|
||||||
|
// example-host-1:*
|
||||||
|
if len(tokens) == expectedTokenItems {
|
||||||
|
alias = tokens[0]
|
||||||
|
} else {
|
||||||
|
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
return alias, tokens[len(tokens)-1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): write/copy tests from expandPorts in policy
|
||||||
|
func parsePorts(portsStr string) ([]tailcfg.PortRange, error) {
|
||||||
|
if portsStr == "*" {
|
||||||
|
return []tailcfg.PortRange{
|
||||||
|
tailcfg.PortRangeAny,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ports []tailcfg.PortRange
|
||||||
|
for _, portStr := range strings.Split(portsStr, ",") {
|
||||||
|
log.Trace().Msgf("parsing portstring: %s", portStr)
|
||||||
|
rang := strings.Split(portStr, "-")
|
||||||
|
switch len(rang) {
|
||||||
|
case 1:
|
||||||
|
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ports = append(ports, tailcfg.PortRange{
|
||||||
|
First: uint16(port),
|
||||||
|
Last: uint16(port),
|
||||||
|
})
|
||||||
|
|
||||||
|
case expectedTokenItems:
|
||||||
|
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ports = append(ports, tailcfg.PortRange{
|
||||||
|
First: uint16(start),
|
||||||
|
Last: uint16(last),
|
||||||
|
})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("invalid ports")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ports, nil
|
||||||
|
}
|
95
hscontrol/policyv2/types_test.go
Normal file
95
hscontrol/policyv2/types_test.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package policyv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/tailscale/hujson"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnmarshalPolicy(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want *ACLPolicy
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
input: "{}",
|
||||||
|
want: &ACLPolicy{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic-types",
|
||||||
|
input: `
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:example": [
|
||||||
|
"testuser@headscale.net",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"group:example",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: &ACLPolicy{
|
||||||
|
Groups: Groups{
|
||||||
|
Group("group:example"): []Username{"testuser@headscale.net"},
|
||||||
|
},
|
||||||
|
ACLs: []ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: Aliases{
|
||||||
|
Group("group:example"),
|
||||||
|
},
|
||||||
|
Destinations: []AliasWithPorts{
|
||||||
|
{
|
||||||
|
Alias: Host("host-1"),
|
||||||
|
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var policy ACLPolicy
|
||||||
|
ast, err := hujson.Parse([]byte(tt.input))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parsing hujson: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ast.Standardize()
|
||||||
|
acl := ast.Pack()
|
||||||
|
|
||||||
|
if err := json.Unmarshal(acl, &policy); err != nil {
|
||||||
|
// TODO: check error type
|
||||||
|
t.Fatalf("unmarshaling json: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, &policy); diff != "" {
|
||||||
|
t.Fatalf("unexpected policy (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue