implement asterix, pass old parsing test

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-17 01:02:17 +02:00
parent 05a9a03358
commit 1c029c365d
No known key found for this signature in database
3 changed files with 81 additions and 52 deletions

View file

@ -358,7 +358,7 @@ func TestParsing(t *testing.T) {
&types.Node{
IPv4: ap("200.200.200.200"),
User: types.User{
Name: "testuser",
Name: "testuser@",
},
Hostinfo: &tailcfg.Hostinfo{},
},

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/netip"
"strconv"
"strings"
@ -50,24 +51,29 @@ func theInternet() *netipx.IPSet {
return theInternetSet
}
type Asterix string
type Asterix int
func (a Asterix) Validate() error {
if a == "*" {
return nil
}
return fmt.Errorf(`Asterix can only be "*", got: %s`, a)
}
func (a *Asterix) String() string {
return string(*a)
}
func (a *Asterix) UnmarshalJSON(b []byte) error {
*a = "*"
return nil
}
func (a Asterix) String() string {
return "*"
}
func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}
func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
ips.AddPrefix(tsaddr.AllIPv4())
ips.AddPrefix(tsaddr.AllIPv6())
return ips.IPSet()
}
// Username is a string that represents a username, it must contain an @.
type Username string
@ -120,8 +126,8 @@ func (g Group) Validate() error {
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
}
func (g Group) UnmarshalJSON(b []byte) error {
g = Group(strings.Trim(string(b), `"`))
func (g *Group) UnmarshalJSON(b []byte) error {
*g = Group(strings.Trim(string(b), `"`))
if err := g.Validate(); err != nil {
return err
}
@ -157,8 +163,8 @@ func (t Tag) Validate() error {
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
}
func (t Tag) UnmarshalJSON(b []byte) error {
t = Tag(strings.Trim(string(b), `"`))
func (t *Tag) UnmarshalJSON(b []byte) error {
*t = Tag(strings.Trim(string(b), `"`))
if err := t.Validate(); err != nil {
return err
}
@ -184,8 +190,8 @@ func (h Host) Validate() error {
return nil
}
func (h Host) UnmarshalJSON(b []byte) error {
h = Host(strings.Trim(string(b), `"`))
func (h *Host) UnmarshalJSON(b []byte) error {
*h = Host(strings.Trim(string(b), `"`))
if err := h.Validate(); err != nil {
return err
}
@ -195,7 +201,15 @@ func (h Host) UnmarshalJSON(b []byte) error {
func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
ips.AddPrefix(netip.Prefix(p.Hosts[h]))
pref, ok := p.Hosts[h]
if !ok {
return nil, fmt.Errorf("unable to resolve host: %q", h)
}
err := pref.Validate()
if err != nil {
return nil, err
}
ips.AddPrefix(netip.Prefix(pref))
return ips.IPSet()
}
@ -275,8 +289,8 @@ func (ag AutoGroup) Validate() error {
return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
}
func (ag AutoGroup) UnmarshalJSON(b []byte) error {
ag = AutoGroup(strings.Trim(string(b), `"`))
func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
*ag = AutoGroup(strings.Trim(string(b), `"`))
if err := ag.Validate(); err != nil {
return err
}
@ -330,6 +344,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
}
ve.Alias = parseAlias(vs)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", vs)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
@ -353,17 +370,22 @@ func parseAlias(vs string) Alias {
switch {
case vs == "*":
return ptr.To(Asterix("*"))
return Asterix(0)
case strings.Contains(vs, "@"):
return ptr.To(Username(vs))
case strings.HasPrefix(vs, "group:"):
ptr.To(Group(vs))
return ptr.To(Group(vs))
case strings.HasPrefix(vs, "tag:"):
return Tag(vs)
return ptr.To(Tag(vs))
case strings.HasPrefix(vs, "autogroup:"):
return AutoGroup(vs)
return ptr.To(AutoGroup(vs))
}
return Host(vs)
if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") {
return ptr.To(Host(vs))
}
return nil
}
// AliasEnc is used to deserialize a Alias.
@ -379,10 +401,13 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
switch val := v.(type) {
case string:
ve.Alias = parseAlias(val)
ve.Alias = parseAlias(val)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", val)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
log.Printf("val: %q as type: %T", val, ve.Alias)
default:
return fmt.Errorf("type %T not supported", val)
}
@ -440,10 +465,9 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
switch {
case strings.Contains(val, "@"):
u := Username(val)
ve.Owner = &u
ve.Owner = ptr.To(Username(val))
case strings.HasPrefix(val, "group:"):
ve.Owner = Group(val)
ve.Owner = ptr.To(Group(val))
}
default:
return fmt.Errorf("type %T not supported", val)

View file

@ -57,7 +57,7 @@ func TestUnmarshalPolicy(t *testing.T) {
"tagOwners": {
"tag:user": ["testuser@headscale.net"],
"tag:group": ["group:other"],
"tag:userandgroup": ["testuser@headscale.net" ,"group:other"],
"tag:userandgroup": ["testuser@headscale.net", "group:other"],
},
"hosts": {
@ -132,9 +132,9 @@ func TestUnmarshalPolicy(t *testing.T) {
Group("group:other"): []Username{Username("otheruser@headscale.net")},
},
TagOwners: TagOwners{
Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))},
Tag("tag:group"): Owners{Group("group:other")},
Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")},
Tag("tag:user"): Owners{up("testuser@headscale.net")},
Tag("tag:group"): Owners{gp("group:other")},
Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")},
},
Hosts: Hosts{
"host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")),
@ -152,14 +152,14 @@ func TestUnmarshalPolicy(t *testing.T) {
// any approved subnets and autogroup:shared.
// It does not allow traffic originating from
// non-tailscale devices (unless it is an approved route).
Host("*"),
hp("*"),
},
Destinations: []AliasWithPorts{
{
// TODO(kradalby): Should this be host?
// It is:
// Includes any destination (no restrictions).
Alias: Host("*"),
Alias: hp("*"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
@ -181,11 +181,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Group("group:example"),
gp("group:example"),
},
Destinations: []AliasWithPorts{
{
Alias: Group("group:other"),
Alias: gp("group:other"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
},
},
@ -194,11 +194,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))),
pp("100.101.102.103/32"),
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))),
Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
},
},
@ -207,11 +207,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "udp",
Sources: Aliases{
ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))),
pp("10.0.0.0/8"),
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))),
Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
},
},
@ -220,11 +220,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Host("subnet-1"),
hp("subnet-1"),
},
Destinations: []AliasWithPorts{
{
Alias: Host("host-1"),
Alias: hp("host-1"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}},
},
},
@ -233,11 +233,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Tag("tag:group"),
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: Tag("tag:user"),
Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80},
tailcfg.PortRange{First: 443, Last: 443},
@ -249,11 +249,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Tag("tag:group"),
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: AutoGroup("autogroup:internet"),
Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80},
},
@ -367,6 +367,11 @@ func TestUnmarshalPolicy(t *testing.T) {
}
}
func gp(s string) *Group { return ptr.To(Group(s)) }
func up(s string) *Username { return ptr.To(Username(s)) }
func hp(s string) *Host { return ptr.To(Host(s)) }
func tp(s string) *Tag { return ptr.To(Tag(s)) }
func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) }
func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) }
func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) }
func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) }
@ -392,7 +397,7 @@ func TestResolvePolicy(t *testing.T) {
"testhost": p("100.100.101.102/32"),
},
},
toResolve: Host("testhost"),
toResolve: hp("testhost"),
want: []netip.Prefix{mp("100.100.101.102/32")},
},
{
@ -491,7 +496,7 @@ func TestResolvePolicy(t *testing.T) {
},
{
name: "tag",
toResolve: Tag("tag:test"),
toResolve: tp("tag:test"),
nodes: types.Nodes{
// Not matching other user
{