implement asterix, pass old parsing test

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-17 01:02:17 +02:00
parent e5e1f15dd9
commit 38f2159c56
No known key found for this signature in database
3 changed files with 81 additions and 52 deletions

View file

@ -358,7 +358,7 @@ func TestParsing(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("200.200.200.200"), IPv4: ap("200.200.200.200"),
User: types.User{ User: types.User{
Name: "testuser", Name: "testuser@",
}, },
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/netip" "net/netip"
"strconv" "strconv"
"strings" "strings"
@ -50,24 +51,29 @@ func theInternet() *netipx.IPSet {
return theInternetSet return theInternetSet
} }
type Asterix string type Asterix int
func (a Asterix) Validate() error { func (a Asterix) Validate() error {
if a == "*" {
return nil
}
return fmt.Errorf(`Asterix can only be "*", got: %s`, a)
}
func (a *Asterix) String() string {
return string(*a)
}
func (a *Asterix) UnmarshalJSON(b []byte) error {
*a = "*"
return nil return nil
} }
func (a Asterix) String() string {
return "*"
}
func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}
func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
ips.AddPrefix(tsaddr.AllIPv4())
ips.AddPrefix(tsaddr.AllIPv6())
return ips.IPSet()
}
// Username is a string that represents a username, it must contain an @. // Username is a string that represents a username, it must contain an @.
type Username string type Username string
@ -120,8 +126,8 @@ func (g Group) Validate() error {
return fmt.Errorf(`Group has to start with "group:", got: %q`, g) return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
} }
func (g Group) UnmarshalJSON(b []byte) error { func (g *Group) UnmarshalJSON(b []byte) error {
g = Group(strings.Trim(string(b), `"`)) *g = Group(strings.Trim(string(b), `"`))
if err := g.Validate(); err != nil { if err := g.Validate(); err != nil {
return err return err
} }
@ -157,8 +163,8 @@ func (t Tag) Validate() error {
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
} }
func (t Tag) UnmarshalJSON(b []byte) error { func (t *Tag) UnmarshalJSON(b []byte) error {
t = Tag(strings.Trim(string(b), `"`)) *t = Tag(strings.Trim(string(b), `"`))
if err := t.Validate(); err != nil { if err := t.Validate(); err != nil {
return err return err
} }
@ -184,8 +190,8 @@ func (h Host) Validate() error {
return nil return nil
} }
func (h Host) UnmarshalJSON(b []byte) error { func (h *Host) UnmarshalJSON(b []byte) error {
h = Host(strings.Trim(string(b), `"`)) *h = Host(strings.Trim(string(b), `"`))
if err := h.Validate(); err != nil { if err := h.Validate(); err != nil {
return err return err
} }
@ -195,7 +201,15 @@ func (h Host) UnmarshalJSON(b []byte) error {
func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
ips.AddPrefix(netip.Prefix(p.Hosts[h])) pref, ok := p.Hosts[h]
if !ok {
return nil, fmt.Errorf("unable to resolve host: %q", h)
}
err := pref.Validate()
if err != nil {
return nil, err
}
ips.AddPrefix(netip.Prefix(pref))
return ips.IPSet() return ips.IPSet()
} }
@ -275,8 +289,8 @@ func (ag AutoGroup) Validate() error {
return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
} }
func (ag AutoGroup) UnmarshalJSON(b []byte) error { func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
ag = AutoGroup(strings.Trim(string(b), `"`)) *ag = AutoGroup(strings.Trim(string(b), `"`))
if err := ag.Validate(); err != nil { if err := ag.Validate(); err != nil {
return err return err
} }
@ -330,6 +344,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
} }
ve.Alias = parseAlias(vs) ve.Alias = parseAlias(vs)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", vs)
}
if err := ve.Alias.Validate(); err != nil { if err := ve.Alias.Validate(); err != nil {
return err return err
} }
@ -353,17 +370,22 @@ func parseAlias(vs string) Alias {
switch { switch {
case vs == "*": case vs == "*":
return ptr.To(Asterix("*")) return Asterix(0)
case strings.Contains(vs, "@"): case strings.Contains(vs, "@"):
return ptr.To(Username(vs)) return ptr.To(Username(vs))
case strings.HasPrefix(vs, "group:"): case strings.HasPrefix(vs, "group:"):
ptr.To(Group(vs)) return ptr.To(Group(vs))
case strings.HasPrefix(vs, "tag:"): case strings.HasPrefix(vs, "tag:"):
return Tag(vs) return ptr.To(Tag(vs))
case strings.HasPrefix(vs, "autogroup:"): case strings.HasPrefix(vs, "autogroup:"):
return AutoGroup(vs) return ptr.To(AutoGroup(vs))
} }
return Host(vs)
if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") {
return ptr.To(Host(vs))
}
return nil
} }
// AliasEnc is used to deserialize a Alias. // AliasEnc is used to deserialize a Alias.
@ -379,10 +401,13 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
switch val := v.(type) { switch val := v.(type) {
case string: case string:
ve.Alias = parseAlias(val) ve.Alias = parseAlias(val)
ve.Alias = parseAlias(val) if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", val)
}
if err := ve.Alias.Validate(); err != nil { if err := ve.Alias.Validate(); err != nil {
return err return err
} }
log.Printf("val: %q as type: %T", val, ve.Alias)
default: default:
return fmt.Errorf("type %T not supported", val) return fmt.Errorf("type %T not supported", val)
} }
@ -440,10 +465,9 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
switch { switch {
case strings.Contains(val, "@"): case strings.Contains(val, "@"):
u := Username(val) ve.Owner = ptr.To(Username(val))
ve.Owner = &u
case strings.HasPrefix(val, "group:"): case strings.HasPrefix(val, "group:"):
ve.Owner = Group(val) ve.Owner = ptr.To(Group(val))
} }
default: default:
return fmt.Errorf("type %T not supported", val) return fmt.Errorf("type %T not supported", val)

View file

@ -57,7 +57,7 @@ func TestUnmarshalPolicy(t *testing.T) {
"tagOwners": { "tagOwners": {
"tag:user": ["testuser@headscale.net"], "tag:user": ["testuser@headscale.net"],
"tag:group": ["group:other"], "tag:group": ["group:other"],
"tag:userandgroup": ["testuser@headscale.net" ,"group:other"], "tag:userandgroup": ["testuser@headscale.net", "group:other"],
}, },
"hosts": { "hosts": {
@ -132,9 +132,9 @@ func TestUnmarshalPolicy(t *testing.T) {
Group("group:other"): []Username{Username("otheruser@headscale.net")}, Group("group:other"): []Username{Username("otheruser@headscale.net")},
}, },
TagOwners: TagOwners{ TagOwners: TagOwners{
Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))}, Tag("tag:user"): Owners{up("testuser@headscale.net")},
Tag("tag:group"): Owners{Group("group:other")}, Tag("tag:group"): Owners{gp("group:other")},
Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")}, Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")},
}, },
Hosts: Hosts{ Hosts: Hosts{
"host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")), "host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")),
@ -152,14 +152,14 @@ func TestUnmarshalPolicy(t *testing.T) {
// any approved subnets and autogroup:shared. // any approved subnets and autogroup:shared.
// It does not allow traffic originating from // It does not allow traffic originating from
// non-tailscale devices (unless it is an approved route). // non-tailscale devices (unless it is an approved route).
Host("*"), hp("*"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
// TODO(kradalby): Should this be host? // TODO(kradalby): Should this be host?
// It is: // It is:
// Includes any destination (no restrictions). // Includes any destination (no restrictions).
Alias: Host("*"), Alias: hp("*"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
}, },
}, },
@ -181,11 +181,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "tcp", Protocol: "tcp",
Sources: Aliases{ Sources: Aliases{
Group("group:example"), gp("group:example"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: Group("group:other"), Alias: gp("group:other"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
}, },
}, },
@ -194,11 +194,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "tcp", Protocol: "tcp",
Sources: Aliases{ Sources: Aliases{
ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))), pp("100.101.102.103/32"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))), Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
}, },
}, },
@ -207,11 +207,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "udp", Protocol: "udp",
Sources: Aliases{ Sources: Aliases{
ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))), pp("10.0.0.0/8"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))), Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
}, },
}, },
@ -220,11 +220,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "tcp", Protocol: "tcp",
Sources: Aliases{ Sources: Aliases{
Host("subnet-1"), hp("subnet-1"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: Host("host-1"), Alias: hp("host-1"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}},
}, },
}, },
@ -233,11 +233,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "tcp", Protocol: "tcp",
Sources: Aliases{ Sources: Aliases{
Tag("tag:group"), tp("tag:group"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: Tag("tag:user"), Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{ Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80}, tailcfg.PortRange{First: 80, Last: 80},
tailcfg.PortRange{First: 443, Last: 443}, tailcfg.PortRange{First: 443, Last: 443},
@ -249,11 +249,11 @@ func TestUnmarshalPolicy(t *testing.T) {
Action: "accept", Action: "accept",
Protocol: "tcp", Protocol: "tcp",
Sources: Aliases{ Sources: Aliases{
Tag("tag:group"), tp("tag:group"),
}, },
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: AutoGroup("autogroup:internet"), Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{ Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80}, tailcfg.PortRange{First: 80, Last: 80},
}, },
@ -367,6 +367,11 @@ func TestUnmarshalPolicy(t *testing.T) {
} }
} }
func gp(s string) *Group { return ptr.To(Group(s)) }
func up(s string) *Username { return ptr.To(Username(s)) }
func hp(s string) *Host { return ptr.To(Host(s)) }
func tp(s string) *Tag { return ptr.To(Tag(s)) }
func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) }
func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) }
func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) }
func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) } func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) }
@ -392,7 +397,7 @@ func TestResolvePolicy(t *testing.T) {
"testhost": p("100.100.101.102/32"), "testhost": p("100.100.101.102/32"),
}, },
}, },
toResolve: Host("testhost"), toResolve: hp("testhost"),
want: []netip.Prefix{mp("100.100.101.102/32")}, want: []netip.Prefix{mp("100.100.101.102/32")},
}, },
{ {
@ -491,7 +496,7 @@ func TestResolvePolicy(t *testing.T) {
}, },
{ {
name: "tag", name: "tag",
toResolve: Tag("tag:test"), toResolve: tp("tag:test"),
nodes: types.Nodes{ nodes: types.Nodes{
// Not matching other user // Not matching other user
{ {