From 4f46d6513bc96eefb9b3897641b0f13eb9042fec Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 16 Oct 2024 21:56:32 +0200 Subject: [PATCH] loads Signed-off-by: Kristoffer Dalby --- hscontrol/policyv2/filter.go | 79 +++++ hscontrol/policyv2/filter_test.go | 378 ++++++++++++++++++++++ hscontrol/policyv2/types.go | 468 ++++++++++++++++++++++----- hscontrol/policyv2/types_test.go | 507 ++++++++++++++++++++++++++++-- hscontrol/types/node.go | 45 +++ 5 files changed, 1380 insertions(+), 97 deletions(-) create mode 100644 hscontrol/policyv2/filter.go create mode 100644 hscontrol/policyv2/filter_test.go diff --git a/hscontrol/policyv2/filter.go b/hscontrol/policyv2/filter.go new file mode 100644 index 00000000..27f8c8d8 --- /dev/null +++ b/hscontrol/policyv2/filter.go @@ -0,0 +1,79 @@ +package policyv2 + +import ( + "errors" + "fmt" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +var ( + ErrInvalidAction = errors.New("invalid action") +) + +// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a +// set of Tailscale compatible FilterRules used to allow traffic on clients. +func (pol *Policy) CompileFilterRules( + nodes types.Nodes, +) ([]tailcfg.FilterRule, error) { + if pol == nil { + return tailcfg.FilterAllowAll, nil + } + + var rules []tailcfg.FilterRule + + for _, acl := range pol.ACLs { + if acl.Action != "accept" { + return nil, ErrInvalidAction + } + + srcIPs, err := acl.Sources.Resolve(pol, nodes) + if err != nil { + return nil, fmt.Errorf("resolving source ips: %w", err) + } + + // TODO(kradalby): integrate type into schema + // TODO(kradalby): figure out the _ is wildcard stuff + protocols, _, err := parseProtocol(acl.Protocol) + if err != nil { + return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) + } + + var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { + ips, err := dest.Alias.Resolve(pol, nodes) + if err != nil { + return nil, err + } + + for _, pref := range ips.Prefixes() { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + rules = append(rules, tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcIPs), + DstPorts: destPorts, + IPProto: protocols, + }) + } + + return rules, nil +} + +func ipSetToPrefixStringList(ips *netipx.IPSet) []string { + var out []string + + for _, pref := range ips.Prefixes() { + out = append(out, pref.String()) + } + return out +} diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go new file mode 100644 index 00000000..4edf7233 --- /dev/null +++ b/hscontrol/policyv2/filter_test.go @@ -0,0 +1,378 @@ +package policyv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +// TODO(kradalby): +// Convert policy.TestReduceFilterRules to take JSON +// Move it here, run it against both old and new CompileFilterRules + +func TestParsing(t *testing.T) { + tests := []struct { + name string + format string + acl string + want []tailcfg.FilterRule + wantErr bool + }{ + { + name: "invalid-hujson", + format: "hujson", + acl: ` +{ + `, + want: []tailcfg.FilterRule{}, + wantErr: true, + }, + // The new parser will ignore all that is irrelevant + // { + // name: "valid-hujson-invalid-content", + // format: "hujson", + // acl: ` + // { + // "valid_json": true, + // "but_a_policy_though": false + // } + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + // { + // name: "invalid-cidr", + // format: "hujson", + // acl: ` + // {"example-host-1": "100.100.100.100/42"} + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + { + name: "basic-rule", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], +} + `, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "parse-protocol", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], +}`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{protocolUDP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolICMP, protocolIPv6ICMP}, + }, + }, + wantErr: false, + }, + { + name: "port-wildcard", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-range", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.100.100.100/32", + Ports: tailcfg.PortRange{First: 5400, Last: 5500}, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "port-group", + format: "hujson", + acl: ` +{ + "groups": { + "group:example": [ + "testuser@", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-user", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser@", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "ipv6", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100/32", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol, err := PolicyFromBytes([]byte(tt.acl)) + if tt.wantErr && err == nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } else if !tt.wantErr && err != nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if err != nil { + return + } + + rules, err := pol.CompileFilterRules(types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), + }, + &types.Node{ + IPv4: ap("200.200.200.200"), + User: types.User{ + Name: "testuser", + }, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) + + if (err != nil) != tt.wantErr { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, rules); diff != "" { + t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index 2f7af07b..3bc65e1c 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -9,123 +9,293 @@ import ( "strconv" "strings" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" + "github.com/tailscale/hujson" + "go4.org/netipx" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) +var theInternetSet *netipx.IPSet + +// theInternet returns the IPSet for the Internet. +// https://www.youtube.com/watch?v=iDbyYGrswtg +func theInternet() *netipx.IPSet { + if theInternetSet != nil { + return theInternetSet + } + + var internetBuilder netipx.IPSetBuilder + internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3")) + internetBuilder.AddPrefix(tsaddr.AllIPv4()) + + // Delete Private network addresses + // https://datatracker.ietf.org/doc/html/rfc1918 + internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16")) + + // Delete Tailscale networks + internetBuilder.RemovePrefix(tsaddr.TailscaleULARange()) + internetBuilder.RemovePrefix(tsaddr.CGNATRange()) + + // Delete "cant find DHCP networks" + internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-loca + internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) + + theInternetSet, _ := internetBuilder.IPSet() + return theInternetSet +} + +type Asterix string + +func (a Asterix) Validate() error { + if a == "*" { + return nil + } + return fmt.Errorf(`Asterix can only be "*", got: %s`, a) +} + +func (a *Asterix) String() string { + return string(*a) +} + +func (a *Asterix) UnmarshalJSON(b []byte) error { + *a = "*" + return nil +} + // Username is a string that represents a username, it must contain an @. type Username string -func (u Username) Valid() bool { - return strings.Contains(string(u), "@") +func (u Username) Validate() error { + if strings.Contains(string(u), "@") { + return nil + } + return fmt.Errorf("Username has to contain @, got: %q", u) } -func (u Username) UnmarshalJSON(b []byte) error { - u = Username(strings.Trim(string(b), `"`)) - if !u.Valid() { - return fmt.Errorf("invalid username %q", u) +func (u *Username) String() string { + return string(*u) +} + +func (u *Username) UnmarshalJSON(b []byte) error { + *u = Username(strings.Trim(string(b), `"`)) + if err := u.Validate(); err != nil { + return err } return nil } +func (u Username) CanBeTagOwner() bool { + return true +} + +func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, node := range nodes { + if node.IsTagged() { + continue + } + + if node.User.Username() == string(u) { + node.AppendToIPSet(&ips) + } + } + + return ips.IPSet() +} + // Group is a special string which is always prefixed with `group:` type Group string -func (g Group) Valid() bool { - return strings.HasPrefix(string(g), "group:") +func (g Group) Validate() error { + if strings.HasPrefix(string(g), "group:") { + return nil + } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } func (g Group) UnmarshalJSON(b []byte) error { g = Group(strings.Trim(string(b), `"`)) - if !g.Valid() { - return fmt.Errorf("invalid group %q", g) + if err := g.Validate(); err != nil { + return err } return nil } +func (g Group) CanBeTagOwner() bool { + return true +} + +func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, user := range p.Groups[g] { + uips, err := user.Resolve(nil, nodes) + if err != nil { + return nil, err + } + + ips.AddSet(uips) + } + + return ips.IPSet() +} + // Tag is a special string which is always prefixed with `tag:` type Tag string -func (t Tag) Valid() bool { - return strings.HasPrefix(string(t), "tag:") +func (t Tag) Validate() error { + if strings.HasPrefix(string(t), "tag:") { + return nil + } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } func (t Tag) UnmarshalJSON(b []byte) error { t = Tag(strings.Trim(string(b), `"`)) - if !t.Valid() { - return fmt.Errorf("invalid tag %q", t) + if err := t.Validate(); err != nil { + return err } return nil } +func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, node := range nodes { + if node.HasTag(string(t)) { + node.AppendToIPSet(&ips) + } + } + + return ips.IPSet() +} + // Host is a string that represents a hostname. type Host string -func (h Host) Valid() bool { - return true +func (h Host) Validate() error { + return nil } func (h Host) UnmarshalJSON(b []byte) error { h = Host(strings.Trim(string(b), `"`)) - if !h.Valid() { - return fmt.Errorf("invalid host %q", h) - } - return nil -} - -type Addr netip.Addr - -func (a Addr) Valid() bool { - return netip.Addr(a).IsValid() -} - -func (a Addr) UnmarshalJSON(b []byte) error { - a = Addr(netip.Addr{}) - if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil { + if err := h.Validate(); err != nil { return err } - if !a.Valid() { - return fmt.Errorf("invalid address %v", a) - } return nil } +func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(netip.Prefix(p.Hosts[h])) + + return ips.IPSet() +} + type Prefix netip.Prefix -func (p Prefix) Valid() bool { - return netip.Prefix(p).IsValid() +func (p Prefix) Validate() error { + if !netip.Prefix(p).IsValid() { + return fmt.Errorf("Prefix %q is invalid", p) + } + + return nil } -func (p Prefix) UnmarshalJSON(b []byte) error { - p = Prefix(netip.Prefix{}) - if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil { +func (p Prefix) String() string { + return netip.Prefix(p).String() +} + +func (p *Prefix) parseString(addr string) error { + if !strings.Contains(addr, "/") { + addr, err := netip.ParseAddr(addr) + if err != nil { + return err + } + addrPref, err := addr.Prefix(addr.BitLen()) + if err != nil { + return err + } + + *p = Prefix(addrPref) + return nil + } + + pref, err := netip.ParsePrefix(addr) + if err != nil { return err } - if !p.Valid() { - return fmt.Errorf("invalid prefix %v", p) + *p = Prefix(pref) + return nil +} + +func (p *Prefix) UnmarshalJSON(b []byte) error { + err := p.parseString(strings.Trim(string(b), `"`)) + if err != nil { + return err + } + if err := p.Validate(); err != nil { + return err } return nil } +func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(netip.Prefix(p)) + + return ips.IPSet() +} + // AutoGroup is a special string which is always prefixed with `autogroup:` type AutoGroup string -func (ag AutoGroup) Valid() bool { - return strings.HasPrefix(string(ag), "autogroup:") +const ( + AutoGroupInternet = "autogroup:internet" +) + +var autogroups = []string{AutoGroupInternet} + +func (ag AutoGroup) Validate() error { + for _, valid := range autogroups { + if valid == string(ag) { + return nil + } + } + + return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) } func (ag AutoGroup) UnmarshalJSON(b []byte) error { ag = AutoGroup(strings.Trim(string(b), `"`)) - if !ag.Valid() { - return fmt.Errorf("invalid autogroup %q", ag) + if err := ag.Validate(); err != nil { + return err } return nil } +func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { + switch ag { + case AutoGroupInternet: + return theInternet(), nil + } + + return nil, nil +} + type Alias interface { - Valid() bool + Validate() error UnmarshalJSON([]byte) error + Resolve(*Policy, types.Nodes) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -160,6 +330,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } ve.Alias = parseAlias(vs) + if err := ve.Alias.Validate(); err != nil { + return err + } default: return fmt.Errorf("type %T not supported", vs) @@ -172,19 +345,19 @@ func parseAlias(vs string) Alias { // ve.Alias = Addr(val) // case netip.Prefix: // ve.Alias = Prefix(val) - if addr, err := netip.ParseAddr(vs); err == nil { - return Addr(addr) - } - - if prefix, err := netip.ParsePrefix(vs); err == nil { - return Prefix(prefix) + var pref Prefix + err := pref.parseString(vs) + if err == nil { + return &pref } switch { + case vs == "*": + return ptr.To(Asterix("*")) case strings.Contains(vs, "@"): - return Username(vs) + return ptr.To(Username(vs)) case strings.HasPrefix(vs, "group:"): - return Group(vs) + ptr.To(Group(vs)) case strings.HasPrefix(vs, "tag:"): return Tag(vs) case strings.HasPrefix(vs, "autogroup:"): @@ -206,6 +379,10 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { switch val := v.(type) { case string: ve.Alias = parseAlias(val) + ve.Alias = parseAlias(val) + if err := ve.Alias.Validate(); err != nil { + return err + } default: return fmt.Errorf("type %T not supported", val) } @@ -228,24 +405,78 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { return nil } -// UserEntity is an interface that represents something that can -// return a list of users: -// - Username -// - Group -// - AutoGroup -type UserEntity interface { - Users() []Username +func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, alias := range a { + aips, err := alias.Resolve(p, nodes) + if err != nil { + return nil, err + } + + ips.AddSet(aips) + } + + return ips.IPSet() +} + +type Owner interface { + CanBeTagOwner() bool UnmarshalJSON([]byte) error } +// OwnerEnc is used to deserialize a Owner. +type OwnerEnc struct{ Owner } + +func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { + // TODO(kradalby): use encoding/json/v2 (go-json-experiment) + dec := json.NewDecoder(bytes.NewReader(b)) + var v any + if err := dec.Decode(&v); err != nil { + return err + } + switch val := v.(type) { + case string: + + switch { + case strings.Contains(val, "@"): + u := Username(val) + ve.Owner = &u + case strings.HasPrefix(val, "group:"): + ve.Owner = Group(val) + } + default: + return fmt.Errorf("type %T not supported", val) + } + return nil +} + +type Owners []Owner + +func (o *Owners) UnmarshalJSON(b []byte) error { + var owners []OwnerEnc + err := json.Unmarshal(b, &owners) + if err != nil { + return err + } + + *o = make([]Owner, len(owners)) + for i, owner := range owners { + (*o)[i] = owner.Owner + } + return nil +} + +type Usernames []Username + // Groups are a map of Group to a list of Username. -type Groups map[Group][]Username +type Groups map[Group]Usernames // Hosts are alias for IP addresses or subnets. -type Hosts map[Host]netip.Prefix +type Hosts map[Host]Prefix // TagOwners are a map of Tag to a list of the UserEntities that own the tag. -type TagOwners map[Tag][]UserEntity +type TagOwners map[Tag]Owners type AutoApprovers struct { Routes map[string][]string `json:"routes"` @@ -259,16 +490,45 @@ type ACL struct { Destinations []AliasWithPorts `json:"dst"` } -// ACLPolicy represents a Tailscale ACL Policy. -type ACLPolicy struct { - Groups Groups `json:"groups"` - // Hosts Hosts `json:"hosts"` +// Policy represents a Tailscale Network Policy. +// TODO(kradalby): +// Add validation method checking: +// All users exists +// All groups and users are valid tag TagOwners +// Everything referred to in ACLs exists in other +// entities. +type Policy struct { + // validated is set if the policy has been validated. + // It is not safe to use before it is validated, and + // callers using it should panic if not + validated bool `json:"-"` + + Groups Groups `json:"groups"` + Hosts Hosts `json:"hosts"` TagOwners TagOwners `json:"tagOwners"` ACLs []ACL `json:"acls"` AutoApprovers AutoApprovers `json:"autoApprovers"` // SSHs []SSH `json:"ssh"` } +func PolicyFromBytes(b []byte) (*Policy, error) { + var policy Policy + ast, err := hujson.Parse(b) + if err != nil { + return nil, fmt.Errorf("parsing HuJSON: %w", err) + } + + ast.Standardize() + acl := ast.Pack() + + err = json.Unmarshal(acl, &policy) + if err != nil { + return nil, fmt.Errorf("parsing policy from bytes: %w", err) + } + + return &policy, nil +} + const ( expectedTokenItems = 2 ) @@ -329,7 +589,6 @@ func parsePorts(portsStr string) ([]tailcfg.PortRange, error) { var ports []tailcfg.PortRange for _, portStr := range strings.Split(portsStr, ",") { - log.Trace().Msgf("parsing portstring: %s", portStr) rang := strings.Split(portStr, "-") switch len(rang) { case 1: @@ -363,3 +622,72 @@ func parsePorts(portsStr string) ([]tailcfg.PortRange, error) { return ports, nil } + +// For some reason golang.org/x/net/internal/iana is an internal package. +const ( + protocolICMP = 1 // Internet Control Message + protocolIGMP = 2 // Internet Group Management + protocolIPv4 = 4 // IPv4 encapsulation + protocolTCP = 6 // Transmission Control + protocolEGP = 8 // Exterior Gateway Protocol + protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) + protocolUDP = 17 // User Datagram + protocolGRE = 47 // Generic Routing Encapsulation + protocolESP = 50 // Encap Security Payload + protocolAH = 51 // Authentication Header + protocolIPv6ICMP = 58 // ICMP for IPv6 + protocolSCTP = 132 // Stream Control Transmission Protocol + ProtocolFC = 133 // Fibre Channel +) + +// parseProtocol reads the proto field of the ACL and generates a list of +// protocols that will be allowed, following the IANA IP protocol number +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// +// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, +// as per Tailscale behaviour (see tailcfg.FilterRule). +// +// Also returns a boolean indicating if the protocol +// requires all the destinations to use wildcard as port number (only TCP, +// UDP and SCTP support specifying ports). +func parseProtocol(protocol string) ([]int, bool, error) { + switch protocol { + case "": + return nil, false, nil + case "igmp": + return []int{protocolIGMP}, true, nil + case "ipv4", "ip-in-ip": + return []int{protocolIPv4}, true, nil + case "tcp": + return []int{protocolTCP}, false, nil + case "egp": + return []int{protocolEGP}, true, nil + case "igp": + return []int{protocolIGP}, true, nil + case "udp": + return []int{protocolUDP}, false, nil + case "gre": + return []int{protocolGRE}, true, nil + case "esp": + return []int{protocolESP}, true, nil + case "ah": + return []int{protocolAH}, true, nil + case "sctp": + return []int{protocolSCTP}, false, nil + case "icmp": + return []int{protocolICMP, protocolIPv6ICMP}, true, nil + + default: + protocolNumber, err := strconv.Atoi(protocol) + if err != nil { + return nil, false, fmt.Errorf("parsing protocol number: %w", err) + } + + // TODO(kradalby): What is this? + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard, nil + } +} diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index ef47bb7d..26423a61 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -1,25 +1,45 @@ package policyv2 import ( - "encoding/json" + "net/netip" "testing" "github.com/google/go-cmp/cmp" - "github.com/tailscale/hujson" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) func TestUnmarshalPolicy(t *testing.T) { tests := []struct { name string input string - want *ACLPolicy - wantErr error + want *Policy + wantErr string }{ { name: "empty", input: "{}", - want: &ACLPolicy{}, + want: &Policy{}, + }, + { + name: "groups", + input: ` +{ + "groups": { + "group:example": [ + "derp@headscale.net", + ], + }, +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("derp@headscale.net")}, + }, + }, }, { name: "basic-types", @@ -29,67 +49,500 @@ func TestUnmarshalPolicy(t *testing.T) { "group:example": [ "testuser@headscale.net", ], + "group:other": [ + "otheruser@headscale.net", + ], + }, + + "tagOwners": { + "tag:user": ["testuser@headscale.net"], + "tag:group": ["group:other"], + "tag:userandgroup": ["testuser@headscale.net" ,"group:other"], }, "hosts": { "host-1": "100.100.100.100", "subnet-1": "100.100.101.100/24", + "outside": "192.168.0.0/16", }, "acls": [ + // All { "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"], + }, + // Users + { + "action": "accept", + "proto": "tcp", + "src": ["testuser@headscale.net"], + "dst": ["otheruser@headscale.net:80"], + }, + // Groups + { + "action": "accept", + "proto": "tcp", + "src": ["group:example"], + "dst": ["group:other:80"], + }, + // Tailscale IP + { + "action": "accept", + "proto": "tcp", + "src": ["100.101.102.103"], + "dst": ["100.101.102.104:80"], + }, + // Subnet + { + "action": "accept", + "proto": "udp", + "src": ["10.0.0.0/8"], + "dst": ["172.16.0.0/16:80"], + }, + // Hosts + { + "action": "accept", + "proto": "tcp", + "src": ["subnet-1"], + "dst": ["host-1:80-88"], + }, + // Tags + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["tag:user:80,443"], + }, + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["autogroup:internet:80"], }, ], } `, - want: &ACLPolicy{ + want: &Policy{ Groups: Groups{ - Group("group:example"): []Username{"testuser@headscale.net"}, + Group("group:example"): []Username{Username("testuser@headscale.net")}, + Group("group:other"): []Username{Username("otheruser@headscale.net")}, + }, + TagOwners: TagOwners{ + Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))}, + Tag("tag:group"): Owners{Group("group:other")}, + Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")}, + }, + Hosts: Hosts{ + "host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")), + "subnet-1": Prefix(netip.MustParsePrefix("100.100.101.100/24")), + "outside": Prefix(netip.MustParsePrefix("192.168.0.0/16")), }, ACLs: []ACL{ { - Action: "accept", + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + // TODO(kradalby): Should this be host? + // It is: + // All traffic originating from Tailscale devices in your tailnet, + // any approved subnets and autogroup:shared. + // It does not allow traffic originating from + // non-tailscale devices (unless it is an approved route). + Host("*"), + }, + Destinations: []AliasWithPorts{ + { + // TODO(kradalby): Should this be host? + // It is: + // Includes any destination (no restrictions). + Alias: Host("*"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Username("testuser@headscale.net")), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Username("otheruser@headscale.net")), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", Sources: Aliases{ Group("group:example"), }, + Destinations: []AliasWithPorts{ + { + Alias: Group("group:other"), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "udp", + Sources: Aliases{ + ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Host("subnet-1"), + }, Destinations: []AliasWithPorts{ { Alias: Host("host-1"), - Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Tag("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: Tag("tag:user"), + Ports: []tailcfg.PortRange{ + tailcfg.PortRange{First: 80, Last: 80}, + tailcfg.PortRange{First: 443, Last: 443}, + }, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Tag("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: AutoGroup("autogroup:internet"), + Ports: []tailcfg.PortRange{ + tailcfg.PortRange{First: 80, Last: 80}, + }, }, }, }, }, }, }, + { + name: "invalid-username", + input: ` +{ + "groups": { + "group:example": [ + "valid@", + "invalid", + ], + }, +} +`, + wantErr: `Username has to contain @, got: "invalid"`, + }, + { + name: "invalid-group", + input: ` +{ + "groups": { + "grou:example": [ + "valid@", + ], + }, +} +`, + wantErr: `Group has to start with "group:", got: "grou:example"`, + }, + { + name: "group-in-group", + input: ` +{ + "groups": { + "group:inner": [], + "group:example": [ + "group:inner", + ], + }, +} +`, + wantErr: `Username has to contain @, got: "group:inner"`, + }, + { + name: "invalid-prefix", + input: ` +{ + "hosts": { + "derp": "10.0", + }, +} +`, + wantErr: `ParseAddr("10.0"): IPv4 address too short`, + }, + { + name: "invalid-auto-group", + input: ` +{ + "acls": [ + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["autogroup:invalid:80"], + }, + ], +} +`, + wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`, + }, } + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { + return x == y + })) + cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{})) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var policy ACLPolicy - ast, err := hujson.Parse([]byte(tt.input)) + policy, err := PolicyFromBytes([]byte(tt.input)) + // TODO(kradalby): This error checking is broken, + // but so is my brain, #longflight + if err == nil { + if tt.wantErr == "" { + return + } + t.Fatalf("got success; wanted error %q", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("got error %q; want %q", err, tt.wantErr) + // } else if err.Error() == tt.wantErr { + // return + } + if err != nil { - t.Fatalf("parsing hujson: %s", err) + t.Fatalf("unexpected err: %q", err) } - ast.Standardize() - acl := ast.Pack() - - if err := json.Unmarshal(acl, &policy); err != nil { - // TODO: check error type - t.Fatalf("unmarshaling json: %s", err) - } - - if diff := cmp.Diff(tt.want, &policy); diff != "" { + if diff := cmp.Diff(tt.want, &policy, cmps...); diff != "" { t.Fatalf("unexpected policy (-want +got):\n%s", diff) } }) } } + +func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } +func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } +func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) } +func p(pref string) Prefix { return Prefix(netip.MustParsePrefix(pref)) } + +func TestResolvePolicy(t *testing.T) { + tests := []struct { + name string + nodes types.Nodes + pol *Policy + toResolve Alias + want []netip.Prefix + }{ + { + name: "prefix", + toResolve: pp("100.100.101.101/32"), + want: []netip.Prefix{mp("100.100.101.101/32")}, + }, + { + name: "host", + pol: &Policy{ + Hosts: Hosts{ + "testhost": p("100.100.101.102/32"), + }, + }, + toResolve: Host("testhost"), + want: []netip.Prefix{mp("100.100.101.102/32")}, + }, + { + name: "username", + toResolve: ptr.To(Username("testuser")), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notme", + }, + IPv4: ap("100.100.101.1"), + }, + // Not matching forced tags + { + User: types.User{ + Name: "testuser", + }, + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.2"), + }, + // not matchin pak tag + { + User: types.User{ + Name: "testuser", + }, + AuthKey: &types.PreAuthKey{ + Tags: []string{"alsotagged"}, + }, + IPv4: ap("100.100.101.3"), + }, + { + User: types.User{ + Name: "testuser", + }, + IPv4: ap("100.100.101.103"), + }, + { + User: types.User{ + Name: "testuser", + }, + IPv4: ap("100.100.101.104"), + }, + }, + want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")}, + }, + { + name: "group", + toResolve: ptr.To(Group("group:testgroup")), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notmetoo", + }, + IPv4: ap("100.100.101.4"), + }, + // Not matching forced tags + { + User: types.User{ + Name: "groupuser", + }, + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.5"), + }, + // not matchin pak tag + { + User: types.User{ + Name: "groupuser", + }, + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:alsotagged"}, + }, + IPv4: ap("100.100.101.6"), + }, + { + User: types.User{ + Name: "groupuser", + }, + IPv4: ap("100.100.101.203"), + }, + { + User: types.User{ + Name: "groupuser", + }, + IPv4: ap("100.100.101.204"), + }, + }, + pol: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"groupuser"}, + "group:othergroup": Usernames{"notmetoo"}, + }, + }, + want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")}, + }, + { + name: "tag", + toResolve: Tag("tag:test"), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notmetoo", + }, + IPv4: ap("100.100.101.9"), + }, + // Not matching forced tags + { + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.10"), + }, + // not matchin pak tag + { + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:alsotagged"}, + }, + IPv4: ap("100.100.101.11"), + }, + // Not matching forced tags + { + ForcedTags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), + }, + // not matchin pak tag + { + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:test"}, + }, + IPv4: ap("100.100.101.239"), + }, + }, + // TODO(kradalby): tests handling TagOwners + hostinfo + pol: &Policy{}, + want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes) + if err != nil { + t.Fatalf("failed to resolve: %s", err) + } + + prefs := ips.Prefixes() + + if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" { + t.Fatalf("unexpected prefs (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 36a65062..c8558993 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strconv" "strings" "time" @@ -134,6 +135,50 @@ func (node *Node) IPs() []netip.Addr { return ret } +// IsTagged reports if a device is tagged +// and therefore should not be treated as a +// user owned device. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys) +func (node *Node) IsTagged() bool { + if len(node.ForcedTags) > 0 { + return true + } + + if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 { + return true + } + + if node.Hostinfo == nil { + return false + } + + // TODO(kradalby): Figure out how tagging should work + // and hostinfo.requestedtags. + // Do this in other work. + + return false +} + +// HasTag reports if a node has a given tag. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys) +func (node *Node) HasTag(tag string) bool { + if slices.Contains(node.ForcedTags, tag) { + return true + } + + if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) { + return true + } + + // TODO(kradalby): Figure out how tagging should work + // and hostinfo.requestedtags. + // Do this in other work. + + return false +} + func (node *Node) Prefixes() []netip.Prefix { addrs := []netip.Prefix{} for _, nodeAddress := range node.IPs() {