mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
can the policy be typed at parsetime?
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
cc42fc394a
commit
eca5f79acc
2 changed files with 460 additions and 0 deletions
365
hscontrol/policyv2/types.go
Normal file
365
hscontrol/policyv2/types.go
Normal file
|
@ -0,0 +1,365 @@
|
|||
package policyv2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Username is a string that represents a username, it must contain an @.
|
||||
type Username string
|
||||
|
||||
func (u Username) Valid() bool {
|
||||
return strings.Contains(string(u), "@")
|
||||
}
|
||||
|
||||
func (u Username) UnmarshalJSON(b []byte) error {
|
||||
u = Username(strings.Trim(string(b), `"`))
|
||||
if !u.Valid() {
|
||||
return fmt.Errorf("invalid username %q", u)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Group is a special string which is always prefixed with `group:`
|
||||
type Group string
|
||||
|
||||
func (g Group) Valid() bool {
|
||||
return strings.HasPrefix(string(g), "group:")
|
||||
}
|
||||
|
||||
func (g Group) UnmarshalJSON(b []byte) error {
|
||||
g = Group(strings.Trim(string(b), `"`))
|
||||
if !g.Valid() {
|
||||
return fmt.Errorf("invalid group %q", g)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tag is a special string which is always prefixed with `tag:`
|
||||
type Tag string
|
||||
|
||||
func (t Tag) Valid() bool {
|
||||
return strings.HasPrefix(string(t), "tag:")
|
||||
}
|
||||
|
||||
func (t Tag) UnmarshalJSON(b []byte) error {
|
||||
t = Tag(strings.Trim(string(b), `"`))
|
||||
if !t.Valid() {
|
||||
return fmt.Errorf("invalid tag %q", t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Host is a string that represents a hostname.
|
||||
type Host string
|
||||
|
||||
func (h Host) Valid() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h Host) UnmarshalJSON(b []byte) error {
|
||||
h = Host(strings.Trim(string(b), `"`))
|
||||
if !h.Valid() {
|
||||
return fmt.Errorf("invalid host %q", h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Addr netip.Addr
|
||||
|
||||
func (a Addr) Valid() bool {
|
||||
return netip.Addr(a).IsValid()
|
||||
}
|
||||
|
||||
func (a Addr) UnmarshalJSON(b []byte) error {
|
||||
a = Addr(netip.Addr{})
|
||||
if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil {
|
||||
return err
|
||||
}
|
||||
if !a.Valid() {
|
||||
return fmt.Errorf("invalid address %v", a)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prefix netip.Prefix
|
||||
|
||||
func (p Prefix) Valid() bool {
|
||||
return netip.Prefix(p).IsValid()
|
||||
}
|
||||
|
||||
func (p Prefix) UnmarshalJSON(b []byte) error {
|
||||
p = Prefix(netip.Prefix{})
|
||||
if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil {
|
||||
return err
|
||||
}
|
||||
if !p.Valid() {
|
||||
return fmt.Errorf("invalid prefix %v", p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AutoGroup is a special string which is always prefixed with `autogroup:`
|
||||
type AutoGroup string
|
||||
|
||||
func (ag AutoGroup) Valid() bool {
|
||||
return strings.HasPrefix(string(ag), "autogroup:")
|
||||
}
|
||||
|
||||
func (ag AutoGroup) UnmarshalJSON(b []byte) error {
|
||||
ag = AutoGroup(strings.Trim(string(b), `"`))
|
||||
if !ag.Valid() {
|
||||
return fmt.Errorf("invalid autogroup %q", ag)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Alias interface {
|
||||
Valid() bool
|
||||
UnmarshalJSON([]byte) error
|
||||
}
|
||||
|
||||
type AliasWithPorts struct {
|
||||
Alias
|
||||
Ports []tailcfg.PortRange
|
||||
}
|
||||
|
||||
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
||||
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
|
||||
dec := json.NewDecoder(bytes.NewReader(b))
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch vs := v.(type) {
|
||||
case string:
|
||||
var portsPart string
|
||||
var err error
|
||||
|
||||
if strings.Contains(vs, ":") {
|
||||
vs, portsPart, err = splitDestination(vs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ports, err := parsePorts(portsPart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ve.Ports = ports
|
||||
}
|
||||
|
||||
ve.Alias = parseAlias(vs)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", vs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAlias(vs string) Alias {
|
||||
// case netip.Addr:
|
||||
// ve.Alias = Addr(val)
|
||||
// case netip.Prefix:
|
||||
// ve.Alias = Prefix(val)
|
||||
if addr, err := netip.ParseAddr(vs); err == nil {
|
||||
return Addr(addr)
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(vs); err == nil {
|
||||
return Prefix(prefix)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(vs, "@"):
|
||||
return Username(vs)
|
||||
case strings.HasPrefix(vs, "group:"):
|
||||
return Group(vs)
|
||||
case strings.HasPrefix(vs, "tag:"):
|
||||
return Tag(vs)
|
||||
case strings.HasPrefix(vs, "autogroup:"):
|
||||
return AutoGroup(vs)
|
||||
}
|
||||
return Host(vs)
|
||||
}
|
||||
|
||||
// AliasEnc is used to deserialize a Alias.
|
||||
type AliasEnc struct{ Alias }
|
||||
|
||||
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
|
||||
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
|
||||
dec := json.NewDecoder(bytes.NewReader(b))
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
ve.Alias = parseAlias(val)
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Aliases []Alias
|
||||
|
||||
func (a *Aliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
err := json.Unmarshal(b, &aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*a = make([]Alias, len(aliases))
|
||||
for i, alias := range aliases {
|
||||
(*a)[i] = alias.Alias
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserEntity is an interface that represents something that can
|
||||
// return a list of users:
|
||||
// - Username
|
||||
// - Group
|
||||
// - AutoGroup
|
||||
type UserEntity interface {
|
||||
Users() []Username
|
||||
UnmarshalJSON([]byte) error
|
||||
}
|
||||
|
||||
// Groups are a map of Group to a list of Username.
|
||||
type Groups map[Group][]Username
|
||||
|
||||
// Hosts are alias for IP addresses or subnets.
|
||||
type Hosts map[Host]netip.Prefix
|
||||
|
||||
// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
|
||||
type TagOwners map[Tag][]UserEntity
|
||||
|
||||
type AutoApprovers struct {
|
||||
Routes map[string][]string `json:"routes"`
|
||||
ExitNode []string `json:"exitNode"`
|
||||
}
|
||||
|
||||
type ACL struct {
|
||||
Action string `json:"action"`
|
||||
Protocol string `json:"proto"`
|
||||
Sources Aliases `json:"src"`
|
||||
Destinations []AliasWithPorts `json:"dst"`
|
||||
}
|
||||
|
||||
// ACLPolicy represents a Tailscale ACL Policy.
|
||||
type ACLPolicy struct {
|
||||
Groups Groups `json:"groups"`
|
||||
// Hosts Hosts `json:"hosts"`
|
||||
TagOwners TagOwners `json:"tagOwners"`
|
||||
ACLs []ACL `json:"acls"`
|
||||
AutoApprovers AutoApprovers `json:"autoApprovers"`
|
||||
// SSHs []SSH `json:"ssh"`
|
||||
}
|
||||
|
||||
const (
|
||||
expectedTokenItems = 2
|
||||
)
|
||||
|
||||
// TODO(kradalby): copy tests from parseDestination in policy
|
||||
func splitDestination(dest string) (string, string, error) {
|
||||
var tokens []string
|
||||
|
||||
// Check if there is a IPv4/6:Port combination, IPv6 has more than
|
||||
// three ":".
|
||||
tokens = strings.Split(dest, ":")
|
||||
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
|
||||
port := tokens[len(tokens)-1]
|
||||
|
||||
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
|
||||
|
||||
filteredMaybeIPv6Str := maybeIPv6Str
|
||||
if strings.Contains(maybeIPv6Str, "/") {
|
||||
networkParts := strings.Split(maybeIPv6Str, "/")
|
||||
filteredMaybeIPv6Str = networkParts[0]
|
||||
}
|
||||
|
||||
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to split destination: %v",
|
||||
tokens,
|
||||
)
|
||||
} else {
|
||||
tokens = []string{maybeIPv6Str, port}
|
||||
}
|
||||
}
|
||||
|
||||
var alias string
|
||||
// We can have here stuff like:
|
||||
// git-server:*
|
||||
// 192.168.1.0/24:22
|
||||
// fd7a:115c:a1e0::2:22
|
||||
// fd7a:115c:a1e0::2/128:22
|
||||
// tag:montreal-webserver:80,443
|
||||
// tag:api-server:443
|
||||
// example-host-1:*
|
||||
if len(tokens) == expectedTokenItems {
|
||||
alias = tokens[0]
|
||||
} else {
|
||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||
}
|
||||
|
||||
return alias, tokens[len(tokens)-1], nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): write/copy tests from expandPorts in policy
|
||||
func parsePorts(portsStr string) ([]tailcfg.PortRange, error) {
|
||||
if portsStr == "*" {
|
||||
return []tailcfg.PortRange{
|
||||
tailcfg.PortRangeAny,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var ports []tailcfg.PortRange
|
||||
for _, portStr := range strings.Split(portsStr, ",") {
|
||||
log.Trace().Msgf("parsing portstring: %s", portStr)
|
||||
rang := strings.Split(portStr, "-")
|
||||
switch len(rang) {
|
||||
case 1:
|
||||
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports = append(ports, tailcfg.PortRange{
|
||||
First: uint16(port),
|
||||
Last: uint16(port),
|
||||
})
|
||||
|
||||
case expectedTokenItems:
|
||||
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports = append(ports, tailcfg.PortRange{
|
||||
First: uint16(start),
|
||||
Last: uint16(last),
|
||||
})
|
||||
|
||||
default:
|
||||
return nil, errors.New("invalid ports")
|
||||
}
|
||||
}
|
||||
|
||||
return ports, nil
|
||||
}
|
95
hscontrol/policyv2/types_test.go
Normal file
95
hscontrol/policyv2/types_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package policyv2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/tailscale/hujson"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestUnmarshalPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want *ACLPolicy
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
input: "{}",
|
||||
want: &ACLPolicy{},
|
||||
},
|
||||
{
|
||||
name: "basic-types",
|
||||
input: `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser@headscale.net",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: &ACLPolicy{
|
||||
Groups: Groups{
|
||||
Group("group:example"): []Username{"testuser@headscale.net"},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: Aliases{
|
||||
Group("group:example"),
|
||||
},
|
||||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: Host("host-1"),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var policy ACLPolicy
|
||||
ast, err := hujson.Parse([]byte(tt.input))
|
||||
if err != nil {
|
||||
t.Fatalf("parsing hujson: %s", err)
|
||||
}
|
||||
|
||||
ast.Standardize()
|
||||
acl := ast.Pack()
|
||||
|
||||
if err := json.Unmarshal(acl, &policy); err != nil {
|
||||
// TODO: check error type
|
||||
t.Fatalf("unmarshaling json: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, &policy); diff != "" {
|
||||
t.Fatalf("unexpected policy (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue