diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go index 4edf7233..6b76b168 100644 --- a/hscontrol/policyv2/filter_test.go +++ b/hscontrol/policyv2/filter_test.go @@ -358,7 +358,7 @@ func TestParsing(t *testing.T) { &types.Node{ IPv4: ap("200.200.200.200"), User: types.User{ - Name: "testuser", + Name: "testuser@", }, Hostinfo: &tailcfg.Hostinfo{}, }, diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index 3bc65e1c..2bbb793a 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/netip" "strconv" "strings" @@ -50,24 +51,29 @@ func theInternet() *netipx.IPSet { return theInternetSet } -type Asterix string +type Asterix int func (a Asterix) Validate() error { - if a == "*" { - return nil - } - return fmt.Errorf(`Asterix can only be "*", got: %s`, a) -} - -func (a *Asterix) String() string { - return string(*a) -} - -func (a *Asterix) UnmarshalJSON(b []byte) error { - *a = "*" return nil } +func (a Asterix) String() string { + return "*" +} + +func (a Asterix) UnmarshalJSON(b []byte) error { + return nil +} + +func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(tsaddr.AllIPv4()) + ips.AddPrefix(tsaddr.AllIPv6()) + + return ips.IPSet() +} + // Username is a string that represents a username, it must contain an @. type Username string @@ -120,8 +126,8 @@ func (g Group) Validate() error { return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } -func (g Group) UnmarshalJSON(b []byte) error { - g = Group(strings.Trim(string(b), `"`)) +func (g *Group) UnmarshalJSON(b []byte) error { + *g = Group(strings.Trim(string(b), `"`)) if err := g.Validate(); err != nil { return err } @@ -157,8 +163,8 @@ func (t Tag) Validate() error { return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } -func (t Tag) UnmarshalJSON(b []byte) error { - t = Tag(strings.Trim(string(b), `"`)) +func (t *Tag) UnmarshalJSON(b []byte) error { + *t = Tag(strings.Trim(string(b), `"`)) if err := t.Validate(); err != nil { return err } @@ -184,8 +190,8 @@ func (h Host) Validate() error { return nil } -func (h Host) UnmarshalJSON(b []byte) error { - h = Host(strings.Trim(string(b), `"`)) +func (h *Host) UnmarshalJSON(b []byte) error { + *h = Host(strings.Trim(string(b), `"`)) if err := h.Validate(); err != nil { return err } @@ -195,7 +201,15 @@ func (h Host) UnmarshalJSON(b []byte) error { func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder - ips.AddPrefix(netip.Prefix(p.Hosts[h])) + pref, ok := p.Hosts[h] + if !ok { + return nil, fmt.Errorf("unable to resolve host: %q", h) + } + err := pref.Validate() + if err != nil { + return nil, err + } + ips.AddPrefix(netip.Prefix(pref)) return ips.IPSet() } @@ -275,8 +289,8 @@ func (ag AutoGroup) Validate() error { return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) } -func (ag AutoGroup) UnmarshalJSON(b []byte) error { - ag = AutoGroup(strings.Trim(string(b), `"`)) +func (ag *AutoGroup) UnmarshalJSON(b []byte) error { + *ag = AutoGroup(strings.Trim(string(b), `"`)) if err := ag.Validate(); err != nil { return err } @@ -330,6 +344,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } ve.Alias = parseAlias(vs) + if ve.Alias == nil { + return fmt.Errorf("could not determine the type of %q", vs) + } if err := ve.Alias.Validate(); err != nil { return err } @@ -353,17 +370,22 @@ func parseAlias(vs string) Alias { switch { case vs == "*": - return ptr.To(Asterix("*")) + return Asterix(0) case strings.Contains(vs, "@"): return ptr.To(Username(vs)) case strings.HasPrefix(vs, "group:"): - ptr.To(Group(vs)) + return ptr.To(Group(vs)) case strings.HasPrefix(vs, "tag:"): - return Tag(vs) + return ptr.To(Tag(vs)) case strings.HasPrefix(vs, "autogroup:"): - return AutoGroup(vs) + return ptr.To(AutoGroup(vs)) } - return Host(vs) + + if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") { + return ptr.To(Host(vs)) + } + + return nil } // AliasEnc is used to deserialize a Alias. @@ -379,10 +401,13 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { switch val := v.(type) { case string: ve.Alias = parseAlias(val) - ve.Alias = parseAlias(val) + if ve.Alias == nil { + return fmt.Errorf("could not determine the type of %q", val) + } if err := ve.Alias.Validate(); err != nil { return err } + log.Printf("val: %q as type: %T", val, ve.Alias) default: return fmt.Errorf("type %T not supported", val) } @@ -440,10 +465,9 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { switch { case strings.Contains(val, "@"): - u := Username(val) - ve.Owner = &u + ve.Owner = ptr.To(Username(val)) case strings.HasPrefix(val, "group:"): - ve.Owner = Group(val) + ve.Owner = ptr.To(Group(val)) } default: return fmt.Errorf("type %T not supported", val) diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index 26423a61..018f9951 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -57,7 +57,7 @@ func TestUnmarshalPolicy(t *testing.T) { "tagOwners": { "tag:user": ["testuser@headscale.net"], "tag:group": ["group:other"], - "tag:userandgroup": ["testuser@headscale.net" ,"group:other"], + "tag:userandgroup": ["testuser@headscale.net", "group:other"], }, "hosts": { @@ -132,9 +132,9 @@ func TestUnmarshalPolicy(t *testing.T) { Group("group:other"): []Username{Username("otheruser@headscale.net")}, }, TagOwners: TagOwners{ - Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))}, - Tag("tag:group"): Owners{Group("group:other")}, - Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")}, + Tag("tag:user"): Owners{up("testuser@headscale.net")}, + Tag("tag:group"): Owners{gp("group:other")}, + Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")}, }, Hosts: Hosts{ "host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")), @@ -152,14 +152,14 @@ func TestUnmarshalPolicy(t *testing.T) { // any approved subnets and autogroup:shared. // It does not allow traffic originating from // non-tailscale devices (unless it is an approved route). - Host("*"), + hp("*"), }, Destinations: []AliasWithPorts{ { // TODO(kradalby): Should this be host? // It is: // Includes any destination (no restrictions). - Alias: Host("*"), + Alias: hp("*"), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -181,11 +181,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Group("group:example"), + gp("group:example"), }, Destinations: []AliasWithPorts{ { - Alias: Group("group:other"), + Alias: gp("group:other"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -194,11 +194,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))), + pp("100.101.102.103/32"), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))), + Alias: pp("100.101.102.104/32"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -207,11 +207,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "udp", Sources: Aliases{ - ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))), + pp("10.0.0.0/8"), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))), + Alias: pp("172.16.0.0/16"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -220,11 +220,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Host("subnet-1"), + hp("subnet-1"), }, Destinations: []AliasWithPorts{ { - Alias: Host("host-1"), + Alias: hp("host-1"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, }, }, @@ -233,11 +233,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Tag("tag:group"), + tp("tag:group"), }, Destinations: []AliasWithPorts{ { - Alias: Tag("tag:user"), + Alias: tp("tag:user"), Ports: []tailcfg.PortRange{ tailcfg.PortRange{First: 80, Last: 80}, tailcfg.PortRange{First: 443, Last: 443}, @@ -249,11 +249,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Tag("tag:group"), + tp("tag:group"), }, Destinations: []AliasWithPorts{ { - Alias: AutoGroup("autogroup:internet"), + Alias: agp("autogroup:internet"), Ports: []tailcfg.PortRange{ tailcfg.PortRange{First: 80, Last: 80}, }, @@ -367,6 +367,11 @@ func TestUnmarshalPolicy(t *testing.T) { } } +func gp(s string) *Group { return ptr.To(Group(s)) } +func up(s string) *Username { return ptr.To(Username(s)) } +func hp(s string) *Host { return ptr.To(Host(s)) } +func tp(s string) *Tag { return ptr.To(Tag(s)) } +func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) } func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) } @@ -392,7 +397,7 @@ func TestResolvePolicy(t *testing.T) { "testhost": p("100.100.101.102/32"), }, }, - toResolve: Host("testhost"), + toResolve: hp("testhost"), want: []netip.Prefix{mp("100.100.101.102/32")}, }, { @@ -491,7 +496,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "tag", - toResolve: Tag("tag:test"), + toResolve: tp("tag:test"), nodes: types.Nodes{ // Not matching other user {