From 086f6a005e5671b3582a54bcba764748541f85a9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 4 Oct 2024 17:39:40 +0200 Subject: [PATCH 1/6] can the policy be typed at parsetime? Signed-off-by: Kristoffer Dalby --- hscontrol/policyv2/types.go | 365 +++++++++++++++++++++++++++++++ hscontrol/policyv2/types_test.go | 95 ++++++++ 2 files changed, 460 insertions(+) create mode 100644 hscontrol/policyv2/types.go create mode 100644 hscontrol/policyv2/types_test.go diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go new file mode 100644 index 00000000..2f7af07b --- /dev/null +++ b/hscontrol/policyv2/types.go @@ -0,0 +1,365 @@ +package policyv2 + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/netip" + "strconv" + "strings" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" +) + +// Username is a string that represents a username, it must contain an @. +type Username string + +func (u Username) Valid() bool { + return strings.Contains(string(u), "@") +} + +func (u Username) UnmarshalJSON(b []byte) error { + u = Username(strings.Trim(string(b), `"`)) + if !u.Valid() { + return fmt.Errorf("invalid username %q", u) + } + return nil +} + +// Group is a special string which is always prefixed with `group:` +type Group string + +func (g Group) Valid() bool { + return strings.HasPrefix(string(g), "group:") +} + +func (g Group) UnmarshalJSON(b []byte) error { + g = Group(strings.Trim(string(b), `"`)) + if !g.Valid() { + return fmt.Errorf("invalid group %q", g) + } + return nil +} + +// Tag is a special string which is always prefixed with `tag:` +type Tag string + +func (t Tag) Valid() bool { + return strings.HasPrefix(string(t), "tag:") +} + +func (t Tag) UnmarshalJSON(b []byte) error { + t = Tag(strings.Trim(string(b), `"`)) + if !t.Valid() { + return fmt.Errorf("invalid tag %q", t) + } + return nil +} + +// Host is a string that represents a hostname. +type Host string + +func (h Host) Valid() bool { + return true +} + +func (h Host) UnmarshalJSON(b []byte) error { + h = Host(strings.Trim(string(b), `"`)) + if !h.Valid() { + return fmt.Errorf("invalid host %q", h) + } + return nil +} + +type Addr netip.Addr + +func (a Addr) Valid() bool { + return netip.Addr(a).IsValid() +} + +func (a Addr) UnmarshalJSON(b []byte) error { + a = Addr(netip.Addr{}) + if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil { + return err + } + if !a.Valid() { + return fmt.Errorf("invalid address %v", a) + } + return nil +} + +type Prefix netip.Prefix + +func (p Prefix) Valid() bool { + return netip.Prefix(p).IsValid() +} + +func (p Prefix) UnmarshalJSON(b []byte) error { + p = Prefix(netip.Prefix{}) + if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil { + return err + } + if !p.Valid() { + return fmt.Errorf("invalid prefix %v", p) + } + return nil +} + +// AutoGroup is a special string which is always prefixed with `autogroup:` +type AutoGroup string + +func (ag AutoGroup) Valid() bool { + return strings.HasPrefix(string(ag), "autogroup:") +} + +func (ag AutoGroup) UnmarshalJSON(b []byte) error { + ag = AutoGroup(strings.Trim(string(b), `"`)) + if !ag.Valid() { + return fmt.Errorf("invalid autogroup %q", ag) + } + return nil +} + +type Alias interface { + Valid() bool + UnmarshalJSON([]byte) error +} + +type AliasWithPorts struct { + Alias + Ports []tailcfg.PortRange +} + +func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { + // TODO(kradalby): use encoding/json/v2 (go-json-experiment) + dec := json.NewDecoder(bytes.NewReader(b)) + var v any + if err := dec.Decode(&v); err != nil { + return err + } + + switch vs := v.(type) { + case string: + var portsPart string + var err error + + if strings.Contains(vs, ":") { + vs, portsPart, err = splitDestination(vs) + if err != nil { + return err + } + + ports, err := parsePorts(portsPart) + if err != nil { + return err + } + ve.Ports = ports + } + + ve.Alias = parseAlias(vs) + + default: + return fmt.Errorf("type %T not supported", vs) + } + return nil +} + +func parseAlias(vs string) Alias { + // case netip.Addr: + // ve.Alias = Addr(val) + // case netip.Prefix: + // ve.Alias = Prefix(val) + if addr, err := netip.ParseAddr(vs); err == nil { + return Addr(addr) + } + + if prefix, err := netip.ParsePrefix(vs); err == nil { + return Prefix(prefix) + } + + switch { + case strings.Contains(vs, "@"): + return Username(vs) + case strings.HasPrefix(vs, "group:"): + return Group(vs) + case strings.HasPrefix(vs, "tag:"): + return Tag(vs) + case strings.HasPrefix(vs, "autogroup:"): + return AutoGroup(vs) + } + return Host(vs) +} + +// AliasEnc is used to deserialize a Alias. +type AliasEnc struct{ Alias } + +func (ve *AliasEnc) UnmarshalJSON(b []byte) error { + // TODO(kradalby): use encoding/json/v2 (go-json-experiment) + dec := json.NewDecoder(bytes.NewReader(b)) + var v any + if err := dec.Decode(&v); err != nil { + return err + } + switch val := v.(type) { + case string: + ve.Alias = parseAlias(val) + default: + return fmt.Errorf("type %T not supported", val) + } + return nil +} + +type Aliases []Alias + +func (a *Aliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + (*a)[i] = alias.Alias + } + return nil +} + +// UserEntity is an interface that represents something that can +// return a list of users: +// - Username +// - Group +// - AutoGroup +type UserEntity interface { + Users() []Username + UnmarshalJSON([]byte) error +} + +// Groups are a map of Group to a list of Username. +type Groups map[Group][]Username + +// Hosts are alias for IP addresses or subnets. +type Hosts map[Host]netip.Prefix + +// TagOwners are a map of Tag to a list of the UserEntities that own the tag. +type TagOwners map[Tag][]UserEntity + +type AutoApprovers struct { + Routes map[string][]string `json:"routes"` + ExitNode []string `json:"exitNode"` +} + +type ACL struct { + Action string `json:"action"` + Protocol string `json:"proto"` + Sources Aliases `json:"src"` + Destinations []AliasWithPorts `json:"dst"` +} + +// ACLPolicy represents a Tailscale ACL Policy. +type ACLPolicy struct { + Groups Groups `json:"groups"` + // Hosts Hosts `json:"hosts"` + TagOwners TagOwners `json:"tagOwners"` + ACLs []ACL `json:"acls"` + AutoApprovers AutoApprovers `json:"autoApprovers"` + // SSHs []SSH `json:"ssh"` +} + +const ( + expectedTokenItems = 2 +) + +// TODO(kradalby): copy tests from parseDestination in policy +func splitDestination(dest string) (string, string, error) { + var tokens []string + + // Check if there is a IPv4/6:Port combination, IPv6 has more than + // three ":". + tokens = strings.Split(dest, ":") + if len(tokens) < expectedTokenItems || len(tokens) > 3 { + port := tokens[len(tokens)-1] + + maybeIPv6Str := strings.TrimSuffix(dest, ":"+port) + + filteredMaybeIPv6Str := maybeIPv6Str + if strings.Contains(maybeIPv6Str, "/") { + networkParts := strings.Split(maybeIPv6Str, "/") + filteredMaybeIPv6Str = networkParts[0] + } + + if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() { + return "", "", fmt.Errorf( + "failed to split destination: %v", + tokens, + ) + } else { + tokens = []string{maybeIPv6Str, port} + } + } + + var alias string + // We can have here stuff like: + // git-server:* + // 192.168.1.0/24:22 + // fd7a:115c:a1e0::2:22 + // fd7a:115c:a1e0::2/128:22 + // tag:montreal-webserver:80,443 + // tag:api-server:443 + // example-host-1:* + if len(tokens) == expectedTokenItems { + alias = tokens[0] + } else { + alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) + } + + return alias, tokens[len(tokens)-1], nil +} + +// TODO(kradalby): write/copy tests from expandPorts in policy +func parsePorts(portsStr string) ([]tailcfg.PortRange, error) { + if portsStr == "*" { + return []tailcfg.PortRange{ + tailcfg.PortRangeAny, + }, nil + } + + var ports []tailcfg.PortRange + for _, portStr := range strings.Split(portsStr, ",") { + log.Trace().Msgf("parsing portstring: %s", portStr) + rang := strings.Split(portStr, "-") + switch len(rang) { + case 1: + port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) + if err != nil { + return nil, err + } + ports = append(ports, tailcfg.PortRange{ + First: uint16(port), + Last: uint16(port), + }) + + case expectedTokenItems: + start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) + if err != nil { + return nil, err + } + last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16) + if err != nil { + return nil, err + } + ports = append(ports, tailcfg.PortRange{ + First: uint16(start), + Last: uint16(last), + }) + + default: + return nil, errors.New("invalid ports") + } + } + + return ports, nil +} diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go new file mode 100644 index 00000000..ef47bb7d --- /dev/null +++ b/hscontrol/policyv2/types_test.go @@ -0,0 +1,95 @@ +package policyv2 + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/tailscale/hujson" + "tailscale.com/tailcfg" +) + +func TestUnmarshalPolicy(t *testing.T) { + tests := []struct { + name string + input string + want *ACLPolicy + wantErr error + }{ + { + name: "empty", + input: "{}", + want: &ACLPolicy{}, + }, + { + name: "basic-types", + input: ` +{ + "groups": { + "group:example": [ + "testuser@headscale.net", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: &ACLPolicy{ + Groups: Groups{ + Group("group:example"): []Username{"testuser@headscale.net"}, + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + Group("group:example"), + }, + Destinations: []AliasWithPorts{ + { + Alias: Host("host-1"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var policy ACLPolicy + ast, err := hujson.Parse([]byte(tt.input)) + if err != nil { + t.Fatalf("parsing hujson: %s", err) + } + + ast.Standardize() + acl := ast.Pack() + + if err := json.Unmarshal(acl, &policy); err != nil { + // TODO: check error type + t.Fatalf("unmarshaling json: %s", err) + } + + if diff := cmp.Diff(tt.want, &policy); diff != "" { + t.Fatalf("unexpected policy (-want +got):\n%s", diff) + } + }) + } +} From e5e1f15dd9024526c9a4605abc279baf00c35553 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 16 Oct 2024 21:56:32 +0200 Subject: [PATCH 2/6] loads Signed-off-by: Kristoffer Dalby --- hscontrol/policyv2/filter.go | 79 +++++ hscontrol/policyv2/filter_test.go | 378 ++++++++++++++++++++++ hscontrol/policyv2/types.go | 468 ++++++++++++++++++++++----- hscontrol/policyv2/types_test.go | 507 ++++++++++++++++++++++++++++-- hscontrol/types/node.go | 45 +++ 5 files changed, 1380 insertions(+), 97 deletions(-) create mode 100644 hscontrol/policyv2/filter.go create mode 100644 hscontrol/policyv2/filter_test.go diff --git a/hscontrol/policyv2/filter.go b/hscontrol/policyv2/filter.go new file mode 100644 index 00000000..27f8c8d8 --- /dev/null +++ b/hscontrol/policyv2/filter.go @@ -0,0 +1,79 @@ +package policyv2 + +import ( + "errors" + "fmt" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +var ( + ErrInvalidAction = errors.New("invalid action") +) + +// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a +// set of Tailscale compatible FilterRules used to allow traffic on clients. +func (pol *Policy) CompileFilterRules( + nodes types.Nodes, +) ([]tailcfg.FilterRule, error) { + if pol == nil { + return tailcfg.FilterAllowAll, nil + } + + var rules []tailcfg.FilterRule + + for _, acl := range pol.ACLs { + if acl.Action != "accept" { + return nil, ErrInvalidAction + } + + srcIPs, err := acl.Sources.Resolve(pol, nodes) + if err != nil { + return nil, fmt.Errorf("resolving source ips: %w", err) + } + + // TODO(kradalby): integrate type into schema + // TODO(kradalby): figure out the _ is wildcard stuff + protocols, _, err := parseProtocol(acl.Protocol) + if err != nil { + return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) + } + + var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { + ips, err := dest.Alias.Resolve(pol, nodes) + if err != nil { + return nil, err + } + + for _, pref := range ips.Prefixes() { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + rules = append(rules, tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcIPs), + DstPorts: destPorts, + IPProto: protocols, + }) + } + + return rules, nil +} + +func ipSetToPrefixStringList(ips *netipx.IPSet) []string { + var out []string + + for _, pref := range ips.Prefixes() { + out = append(out, pref.String()) + } + return out +} diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go new file mode 100644 index 00000000..4edf7233 --- /dev/null +++ b/hscontrol/policyv2/filter_test.go @@ -0,0 +1,378 @@ +package policyv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +// TODO(kradalby): +// Convert policy.TestReduceFilterRules to take JSON +// Move it here, run it against both old and new CompileFilterRules + +func TestParsing(t *testing.T) { + tests := []struct { + name string + format string + acl string + want []tailcfg.FilterRule + wantErr bool + }{ + { + name: "invalid-hujson", + format: "hujson", + acl: ` +{ + `, + want: []tailcfg.FilterRule{}, + wantErr: true, + }, + // The new parser will ignore all that is irrelevant + // { + // name: "valid-hujson-invalid-content", + // format: "hujson", + // acl: ` + // { + // "valid_json": true, + // "but_a_policy_though": false + // } + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + // { + // name: "invalid-cidr", + // format: "hujson", + // acl: ` + // {"example-host-1": "100.100.100.100/42"} + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + { + name: "basic-rule", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], +} + `, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "parse-protocol", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], +}`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{protocolUDP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolICMP, protocolIPv6ICMP}, + }, + }, + wantErr: false, + }, + { + name: "port-wildcard", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-range", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.100.100.100/32", + Ports: tailcfg.PortRange{First: 5400, Last: 5500}, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "port-group", + format: "hujson", + acl: ` +{ + "groups": { + "group:example": [ + "testuser@", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-user", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser@", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "ipv6", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100/32", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol, err := PolicyFromBytes([]byte(tt.acl)) + if tt.wantErr && err == nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } else if !tt.wantErr && err != nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if err != nil { + return + } + + rules, err := pol.CompileFilterRules(types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), + }, + &types.Node{ + IPv4: ap("200.200.200.200"), + User: types.User{ + Name: "testuser", + }, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) + + if (err != nil) != tt.wantErr { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, rules); diff != "" { + t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index 2f7af07b..3bc65e1c 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -9,123 +9,293 @@ import ( "strconv" "strings" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" + "github.com/tailscale/hujson" + "go4.org/netipx" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) +var theInternetSet *netipx.IPSet + +// theInternet returns the IPSet for the Internet. +// https://www.youtube.com/watch?v=iDbyYGrswtg +func theInternet() *netipx.IPSet { + if theInternetSet != nil { + return theInternetSet + } + + var internetBuilder netipx.IPSetBuilder + internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3")) + internetBuilder.AddPrefix(tsaddr.AllIPv4()) + + // Delete Private network addresses + // https://datatracker.ietf.org/doc/html/rfc1918 + internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16")) + + // Delete Tailscale networks + internetBuilder.RemovePrefix(tsaddr.TailscaleULARange()) + internetBuilder.RemovePrefix(tsaddr.CGNATRange()) + + // Delete "cant find DHCP networks" + internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-loca + internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) + + theInternetSet, _ := internetBuilder.IPSet() + return theInternetSet +} + +type Asterix string + +func (a Asterix) Validate() error { + if a == "*" { + return nil + } + return fmt.Errorf(`Asterix can only be "*", got: %s`, a) +} + +func (a *Asterix) String() string { + return string(*a) +} + +func (a *Asterix) UnmarshalJSON(b []byte) error { + *a = "*" + return nil +} + // Username is a string that represents a username, it must contain an @. type Username string -func (u Username) Valid() bool { - return strings.Contains(string(u), "@") +func (u Username) Validate() error { + if strings.Contains(string(u), "@") { + return nil + } + return fmt.Errorf("Username has to contain @, got: %q", u) } -func (u Username) UnmarshalJSON(b []byte) error { - u = Username(strings.Trim(string(b), `"`)) - if !u.Valid() { - return fmt.Errorf("invalid username %q", u) +func (u *Username) String() string { + return string(*u) +} + +func (u *Username) UnmarshalJSON(b []byte) error { + *u = Username(strings.Trim(string(b), `"`)) + if err := u.Validate(); err != nil { + return err } return nil } +func (u Username) CanBeTagOwner() bool { + return true +} + +func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, node := range nodes { + if node.IsTagged() { + continue + } + + if node.User.Username() == string(u) { + node.AppendToIPSet(&ips) + } + } + + return ips.IPSet() +} + // Group is a special string which is always prefixed with `group:` type Group string -func (g Group) Valid() bool { - return strings.HasPrefix(string(g), "group:") +func (g Group) Validate() error { + if strings.HasPrefix(string(g), "group:") { + return nil + } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } func (g Group) UnmarshalJSON(b []byte) error { g = Group(strings.Trim(string(b), `"`)) - if !g.Valid() { - return fmt.Errorf("invalid group %q", g) + if err := g.Validate(); err != nil { + return err } return nil } +func (g Group) CanBeTagOwner() bool { + return true +} + +func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, user := range p.Groups[g] { + uips, err := user.Resolve(nil, nodes) + if err != nil { + return nil, err + } + + ips.AddSet(uips) + } + + return ips.IPSet() +} + // Tag is a special string which is always prefixed with `tag:` type Tag string -func (t Tag) Valid() bool { - return strings.HasPrefix(string(t), "tag:") +func (t Tag) Validate() error { + if strings.HasPrefix(string(t), "tag:") { + return nil + } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } func (t Tag) UnmarshalJSON(b []byte) error { t = Tag(strings.Trim(string(b), `"`)) - if !t.Valid() { - return fmt.Errorf("invalid tag %q", t) + if err := t.Validate(); err != nil { + return err } return nil } +func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, node := range nodes { + if node.HasTag(string(t)) { + node.AppendToIPSet(&ips) + } + } + + return ips.IPSet() +} + // Host is a string that represents a hostname. type Host string -func (h Host) Valid() bool { - return true +func (h Host) Validate() error { + return nil } func (h Host) UnmarshalJSON(b []byte) error { h = Host(strings.Trim(string(b), `"`)) - if !h.Valid() { - return fmt.Errorf("invalid host %q", h) - } - return nil -} - -type Addr netip.Addr - -func (a Addr) Valid() bool { - return netip.Addr(a).IsValid() -} - -func (a Addr) UnmarshalJSON(b []byte) error { - a = Addr(netip.Addr{}) - if err := json.Unmarshal(b, (netip.Addr)(a)); err != nil { + if err := h.Validate(); err != nil { return err } - if !a.Valid() { - return fmt.Errorf("invalid address %v", a) - } return nil } +func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(netip.Prefix(p.Hosts[h])) + + return ips.IPSet() +} + type Prefix netip.Prefix -func (p Prefix) Valid() bool { - return netip.Prefix(p).IsValid() +func (p Prefix) Validate() error { + if !netip.Prefix(p).IsValid() { + return fmt.Errorf("Prefix %q is invalid", p) + } + + return nil } -func (p Prefix) UnmarshalJSON(b []byte) error { - p = Prefix(netip.Prefix{}) - if err := json.Unmarshal(b, (netip.Prefix)(p)); err != nil { +func (p Prefix) String() string { + return netip.Prefix(p).String() +} + +func (p *Prefix) parseString(addr string) error { + if !strings.Contains(addr, "/") { + addr, err := netip.ParseAddr(addr) + if err != nil { + return err + } + addrPref, err := addr.Prefix(addr.BitLen()) + if err != nil { + return err + } + + *p = Prefix(addrPref) + return nil + } + + pref, err := netip.ParsePrefix(addr) + if err != nil { return err } - if !p.Valid() { - return fmt.Errorf("invalid prefix %v", p) + *p = Prefix(pref) + return nil +} + +func (p *Prefix) UnmarshalJSON(b []byte) error { + err := p.parseString(strings.Trim(string(b), `"`)) + if err != nil { + return err + } + if err := p.Validate(); err != nil { + return err } return nil } +func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(netip.Prefix(p)) + + return ips.IPSet() +} + // AutoGroup is a special string which is always prefixed with `autogroup:` type AutoGroup string -func (ag AutoGroup) Valid() bool { - return strings.HasPrefix(string(ag), "autogroup:") +const ( + AutoGroupInternet = "autogroup:internet" +) + +var autogroups = []string{AutoGroupInternet} + +func (ag AutoGroup) Validate() error { + for _, valid := range autogroups { + if valid == string(ag) { + return nil + } + } + + return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) } func (ag AutoGroup) UnmarshalJSON(b []byte) error { ag = AutoGroup(strings.Trim(string(b), `"`)) - if !ag.Valid() { - return fmt.Errorf("invalid autogroup %q", ag) + if err := ag.Validate(); err != nil { + return err } return nil } +func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { + switch ag { + case AutoGroupInternet: + return theInternet(), nil + } + + return nil, nil +} + type Alias interface { - Valid() bool + Validate() error UnmarshalJSON([]byte) error + Resolve(*Policy, types.Nodes) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -160,6 +330,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } ve.Alias = parseAlias(vs) + if err := ve.Alias.Validate(); err != nil { + return err + } default: return fmt.Errorf("type %T not supported", vs) @@ -172,19 +345,19 @@ func parseAlias(vs string) Alias { // ve.Alias = Addr(val) // case netip.Prefix: // ve.Alias = Prefix(val) - if addr, err := netip.ParseAddr(vs); err == nil { - return Addr(addr) - } - - if prefix, err := netip.ParsePrefix(vs); err == nil { - return Prefix(prefix) + var pref Prefix + err := pref.parseString(vs) + if err == nil { + return &pref } switch { + case vs == "*": + return ptr.To(Asterix("*")) case strings.Contains(vs, "@"): - return Username(vs) + return ptr.To(Username(vs)) case strings.HasPrefix(vs, "group:"): - return Group(vs) + ptr.To(Group(vs)) case strings.HasPrefix(vs, "tag:"): return Tag(vs) case strings.HasPrefix(vs, "autogroup:"): @@ -206,6 +379,10 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { switch val := v.(type) { case string: ve.Alias = parseAlias(val) + ve.Alias = parseAlias(val) + if err := ve.Alias.Validate(); err != nil { + return err + } default: return fmt.Errorf("type %T not supported", val) } @@ -228,24 +405,78 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { return nil } -// UserEntity is an interface that represents something that can -// return a list of users: -// - Username -// - Group -// - AutoGroup -type UserEntity interface { - Users() []Username +func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, alias := range a { + aips, err := alias.Resolve(p, nodes) + if err != nil { + return nil, err + } + + ips.AddSet(aips) + } + + return ips.IPSet() +} + +type Owner interface { + CanBeTagOwner() bool UnmarshalJSON([]byte) error } +// OwnerEnc is used to deserialize a Owner. +type OwnerEnc struct{ Owner } + +func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { + // TODO(kradalby): use encoding/json/v2 (go-json-experiment) + dec := json.NewDecoder(bytes.NewReader(b)) + var v any + if err := dec.Decode(&v); err != nil { + return err + } + switch val := v.(type) { + case string: + + switch { + case strings.Contains(val, "@"): + u := Username(val) + ve.Owner = &u + case strings.HasPrefix(val, "group:"): + ve.Owner = Group(val) + } + default: + return fmt.Errorf("type %T not supported", val) + } + return nil +} + +type Owners []Owner + +func (o *Owners) UnmarshalJSON(b []byte) error { + var owners []OwnerEnc + err := json.Unmarshal(b, &owners) + if err != nil { + return err + } + + *o = make([]Owner, len(owners)) + for i, owner := range owners { + (*o)[i] = owner.Owner + } + return nil +} + +type Usernames []Username + // Groups are a map of Group to a list of Username. -type Groups map[Group][]Username +type Groups map[Group]Usernames // Hosts are alias for IP addresses or subnets. -type Hosts map[Host]netip.Prefix +type Hosts map[Host]Prefix // TagOwners are a map of Tag to a list of the UserEntities that own the tag. -type TagOwners map[Tag][]UserEntity +type TagOwners map[Tag]Owners type AutoApprovers struct { Routes map[string][]string `json:"routes"` @@ -259,16 +490,45 @@ type ACL struct { Destinations []AliasWithPorts `json:"dst"` } -// ACLPolicy represents a Tailscale ACL Policy. -type ACLPolicy struct { - Groups Groups `json:"groups"` - // Hosts Hosts `json:"hosts"` +// Policy represents a Tailscale Network Policy. +// TODO(kradalby): +// Add validation method checking: +// All users exists +// All groups and users are valid tag TagOwners +// Everything referred to in ACLs exists in other +// entities. +type Policy struct { + // validated is set if the policy has been validated. + // It is not safe to use before it is validated, and + // callers using it should panic if not + validated bool `json:"-"` + + Groups Groups `json:"groups"` + Hosts Hosts `json:"hosts"` TagOwners TagOwners `json:"tagOwners"` ACLs []ACL `json:"acls"` AutoApprovers AutoApprovers `json:"autoApprovers"` // SSHs []SSH `json:"ssh"` } +func PolicyFromBytes(b []byte) (*Policy, error) { + var policy Policy + ast, err := hujson.Parse(b) + if err != nil { + return nil, fmt.Errorf("parsing HuJSON: %w", err) + } + + ast.Standardize() + acl := ast.Pack() + + err = json.Unmarshal(acl, &policy) + if err != nil { + return nil, fmt.Errorf("parsing policy from bytes: %w", err) + } + + return &policy, nil +} + const ( expectedTokenItems = 2 ) @@ -329,7 +589,6 @@ func parsePorts(portsStr string) ([]tailcfg.PortRange, error) { var ports []tailcfg.PortRange for _, portStr := range strings.Split(portsStr, ",") { - log.Trace().Msgf("parsing portstring: %s", portStr) rang := strings.Split(portStr, "-") switch len(rang) { case 1: @@ -363,3 +622,72 @@ func parsePorts(portsStr string) ([]tailcfg.PortRange, error) { return ports, nil } + +// For some reason golang.org/x/net/internal/iana is an internal package. +const ( + protocolICMP = 1 // Internet Control Message + protocolIGMP = 2 // Internet Group Management + protocolIPv4 = 4 // IPv4 encapsulation + protocolTCP = 6 // Transmission Control + protocolEGP = 8 // Exterior Gateway Protocol + protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) + protocolUDP = 17 // User Datagram + protocolGRE = 47 // Generic Routing Encapsulation + protocolESP = 50 // Encap Security Payload + protocolAH = 51 // Authentication Header + protocolIPv6ICMP = 58 // ICMP for IPv6 + protocolSCTP = 132 // Stream Control Transmission Protocol + ProtocolFC = 133 // Fibre Channel +) + +// parseProtocol reads the proto field of the ACL and generates a list of +// protocols that will be allowed, following the IANA IP protocol number +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// +// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, +// as per Tailscale behaviour (see tailcfg.FilterRule). +// +// Also returns a boolean indicating if the protocol +// requires all the destinations to use wildcard as port number (only TCP, +// UDP and SCTP support specifying ports). +func parseProtocol(protocol string) ([]int, bool, error) { + switch protocol { + case "": + return nil, false, nil + case "igmp": + return []int{protocolIGMP}, true, nil + case "ipv4", "ip-in-ip": + return []int{protocolIPv4}, true, nil + case "tcp": + return []int{protocolTCP}, false, nil + case "egp": + return []int{protocolEGP}, true, nil + case "igp": + return []int{protocolIGP}, true, nil + case "udp": + return []int{protocolUDP}, false, nil + case "gre": + return []int{protocolGRE}, true, nil + case "esp": + return []int{protocolESP}, true, nil + case "ah": + return []int{protocolAH}, true, nil + case "sctp": + return []int{protocolSCTP}, false, nil + case "icmp": + return []int{protocolICMP, protocolIPv6ICMP}, true, nil + + default: + protocolNumber, err := strconv.Atoi(protocol) + if err != nil { + return nil, false, fmt.Errorf("parsing protocol number: %w", err) + } + + // TODO(kradalby): What is this? + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard, nil + } +} diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index ef47bb7d..26423a61 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -1,25 +1,45 @@ package policyv2 import ( - "encoding/json" + "net/netip" "testing" "github.com/google/go-cmp/cmp" - "github.com/tailscale/hujson" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) func TestUnmarshalPolicy(t *testing.T) { tests := []struct { name string input string - want *ACLPolicy - wantErr error + want *Policy + wantErr string }{ { name: "empty", input: "{}", - want: &ACLPolicy{}, + want: &Policy{}, + }, + { + name: "groups", + input: ` +{ + "groups": { + "group:example": [ + "derp@headscale.net", + ], + }, +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("derp@headscale.net")}, + }, + }, }, { name: "basic-types", @@ -29,67 +49,500 @@ func TestUnmarshalPolicy(t *testing.T) { "group:example": [ "testuser@headscale.net", ], + "group:other": [ + "otheruser@headscale.net", + ], + }, + + "tagOwners": { + "tag:user": ["testuser@headscale.net"], + "tag:group": ["group:other"], + "tag:userandgroup": ["testuser@headscale.net" ,"group:other"], }, "hosts": { "host-1": "100.100.100.100", "subnet-1": "100.100.101.100/24", + "outside": "192.168.0.0/16", }, "acls": [ + // All { "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"], + }, + // Users + { + "action": "accept", + "proto": "tcp", + "src": ["testuser@headscale.net"], + "dst": ["otheruser@headscale.net:80"], + }, + // Groups + { + "action": "accept", + "proto": "tcp", + "src": ["group:example"], + "dst": ["group:other:80"], + }, + // Tailscale IP + { + "action": "accept", + "proto": "tcp", + "src": ["100.101.102.103"], + "dst": ["100.101.102.104:80"], + }, + // Subnet + { + "action": "accept", + "proto": "udp", + "src": ["10.0.0.0/8"], + "dst": ["172.16.0.0/16:80"], + }, + // Hosts + { + "action": "accept", + "proto": "tcp", + "src": ["subnet-1"], + "dst": ["host-1:80-88"], + }, + // Tags + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["tag:user:80,443"], + }, + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["autogroup:internet:80"], }, ], } `, - want: &ACLPolicy{ + want: &Policy{ Groups: Groups{ - Group("group:example"): []Username{"testuser@headscale.net"}, + Group("group:example"): []Username{Username("testuser@headscale.net")}, + Group("group:other"): []Username{Username("otheruser@headscale.net")}, + }, + TagOwners: TagOwners{ + Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))}, + Tag("tag:group"): Owners{Group("group:other")}, + Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")}, + }, + Hosts: Hosts{ + "host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")), + "subnet-1": Prefix(netip.MustParsePrefix("100.100.101.100/24")), + "outside": Prefix(netip.MustParsePrefix("192.168.0.0/16")), }, ACLs: []ACL{ { - Action: "accept", + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + // TODO(kradalby): Should this be host? + // It is: + // All traffic originating from Tailscale devices in your tailnet, + // any approved subnets and autogroup:shared. + // It does not allow traffic originating from + // non-tailscale devices (unless it is an approved route). + Host("*"), + }, + Destinations: []AliasWithPorts{ + { + // TODO(kradalby): Should this be host? + // It is: + // Includes any destination (no restrictions). + Alias: Host("*"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Username("testuser@headscale.net")), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Username("otheruser@headscale.net")), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", Sources: Aliases{ Group("group:example"), }, + Destinations: []AliasWithPorts{ + { + Alias: Group("group:other"), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "udp", + Sources: Aliases{ + ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))), + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Host("subnet-1"), + }, Destinations: []AliasWithPorts{ { Alias: Host("host-1"), - Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Tag("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: Tag("tag:user"), + Ports: []tailcfg.PortRange{ + tailcfg.PortRange{First: 80, Last: 80}, + tailcfg.PortRange{First: 443, Last: 443}, + }, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Tag("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: AutoGroup("autogroup:internet"), + Ports: []tailcfg.PortRange{ + tailcfg.PortRange{First: 80, Last: 80}, + }, }, }, }, }, }, }, + { + name: "invalid-username", + input: ` +{ + "groups": { + "group:example": [ + "valid@", + "invalid", + ], + }, +} +`, + wantErr: `Username has to contain @, got: "invalid"`, + }, + { + name: "invalid-group", + input: ` +{ + "groups": { + "grou:example": [ + "valid@", + ], + }, +} +`, + wantErr: `Group has to start with "group:", got: "grou:example"`, + }, + { + name: "group-in-group", + input: ` +{ + "groups": { + "group:inner": [], + "group:example": [ + "group:inner", + ], + }, +} +`, + wantErr: `Username has to contain @, got: "group:inner"`, + }, + { + name: "invalid-prefix", + input: ` +{ + "hosts": { + "derp": "10.0", + }, +} +`, + wantErr: `ParseAddr("10.0"): IPv4 address too short`, + }, + { + name: "invalid-auto-group", + input: ` +{ + "acls": [ + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["autogroup:invalid:80"], + }, + ], +} +`, + wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`, + }, } + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { + return x == y + })) + cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{})) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var policy ACLPolicy - ast, err := hujson.Parse([]byte(tt.input)) + policy, err := PolicyFromBytes([]byte(tt.input)) + // TODO(kradalby): This error checking is broken, + // but so is my brain, #longflight + if err == nil { + if tt.wantErr == "" { + return + } + t.Fatalf("got success; wanted error %q", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("got error %q; want %q", err, tt.wantErr) + // } else if err.Error() == tt.wantErr { + // return + } + if err != nil { - t.Fatalf("parsing hujson: %s", err) + t.Fatalf("unexpected err: %q", err) } - ast.Standardize() - acl := ast.Pack() - - if err := json.Unmarshal(acl, &policy); err != nil { - // TODO: check error type - t.Fatalf("unmarshaling json: %s", err) - } - - if diff := cmp.Diff(tt.want, &policy); diff != "" { + if diff := cmp.Diff(tt.want, &policy, cmps...); diff != "" { t.Fatalf("unexpected policy (-want +got):\n%s", diff) } }) } } + +func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } +func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } +func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) } +func p(pref string) Prefix { return Prefix(netip.MustParsePrefix(pref)) } + +func TestResolvePolicy(t *testing.T) { + tests := []struct { + name string + nodes types.Nodes + pol *Policy + toResolve Alias + want []netip.Prefix + }{ + { + name: "prefix", + toResolve: pp("100.100.101.101/32"), + want: []netip.Prefix{mp("100.100.101.101/32")}, + }, + { + name: "host", + pol: &Policy{ + Hosts: Hosts{ + "testhost": p("100.100.101.102/32"), + }, + }, + toResolve: Host("testhost"), + want: []netip.Prefix{mp("100.100.101.102/32")}, + }, + { + name: "username", + toResolve: ptr.To(Username("testuser")), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notme", + }, + IPv4: ap("100.100.101.1"), + }, + // Not matching forced tags + { + User: types.User{ + Name: "testuser", + }, + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.2"), + }, + // not matchin pak tag + { + User: types.User{ + Name: "testuser", + }, + AuthKey: &types.PreAuthKey{ + Tags: []string{"alsotagged"}, + }, + IPv4: ap("100.100.101.3"), + }, + { + User: types.User{ + Name: "testuser", + }, + IPv4: ap("100.100.101.103"), + }, + { + User: types.User{ + Name: "testuser", + }, + IPv4: ap("100.100.101.104"), + }, + }, + want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")}, + }, + { + name: "group", + toResolve: ptr.To(Group("group:testgroup")), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notmetoo", + }, + IPv4: ap("100.100.101.4"), + }, + // Not matching forced tags + { + User: types.User{ + Name: "groupuser", + }, + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.5"), + }, + // not matchin pak tag + { + User: types.User{ + Name: "groupuser", + }, + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:alsotagged"}, + }, + IPv4: ap("100.100.101.6"), + }, + { + User: types.User{ + Name: "groupuser", + }, + IPv4: ap("100.100.101.203"), + }, + { + User: types.User{ + Name: "groupuser", + }, + IPv4: ap("100.100.101.204"), + }, + }, + pol: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"groupuser"}, + "group:othergroup": Usernames{"notmetoo"}, + }, + }, + want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")}, + }, + { + name: "tag", + toResolve: Tag("tag:test"), + nodes: types.Nodes{ + // Not matching other user + { + User: types.User{ + Name: "notmetoo", + }, + IPv4: ap("100.100.101.9"), + }, + // Not matching forced tags + { + ForcedTags: []string{"tag:anything"}, + IPv4: ap("100.100.101.10"), + }, + // not matchin pak tag + { + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:alsotagged"}, + }, + IPv4: ap("100.100.101.11"), + }, + // Not matching forced tags + { + ForcedTags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), + }, + // not matchin pak tag + { + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:test"}, + }, + IPv4: ap("100.100.101.239"), + }, + }, + // TODO(kradalby): tests handling TagOwners + hostinfo + pol: &Policy{}, + want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes) + if err != nil { + t.Fatalf("failed to resolve: %s", err) + } + + prefs := ips.Prefixes() + + if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" { + t.Fatalf("unexpected prefs (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 9d632bd8..92aba945 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strconv" "strings" "time" @@ -134,6 +135,50 @@ func (node *Node) IPs() []netip.Addr { return ret } +// IsTagged reports if a device is tagged +// and therefore should not be treated as a +// user owned device. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys) +func (node *Node) IsTagged() bool { + if len(node.ForcedTags) > 0 { + return true + } + + if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 { + return true + } + + if node.Hostinfo == nil { + return false + } + + // TODO(kradalby): Figure out how tagging should work + // and hostinfo.requestedtags. + // Do this in other work. + + return false +} + +// HasTag reports if a node has a given tag. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys) +func (node *Node) HasTag(tag string) bool { + if slices.Contains(node.ForcedTags, tag) { + return true + } + + if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) { + return true + } + + // TODO(kradalby): Figure out how tagging should work + // and hostinfo.requestedtags. + // Do this in other work. + + return false +} + func (node *Node) Prefixes() []netip.Prefix { addrs := []netip.Prefix{} for _, nodeAddress := range node.IPs() { From 38f2159c567d8fc706d9037f0f1a1191bb554344 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 01:02:17 +0200 Subject: [PATCH 3/6] implement asterix, pass old parsing test Signed-off-by: Kristoffer Dalby --- hscontrol/policyv2/filter_test.go | 2 +- hscontrol/policyv2/types.go | 86 ++++++++++++++++++++----------- hscontrol/policyv2/types_test.go | 45 +++++++++------- 3 files changed, 81 insertions(+), 52 deletions(-) diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go index 4edf7233..6b76b168 100644 --- a/hscontrol/policyv2/filter_test.go +++ b/hscontrol/policyv2/filter_test.go @@ -358,7 +358,7 @@ func TestParsing(t *testing.T) { &types.Node{ IPv4: ap("200.200.200.200"), User: types.User{ - Name: "testuser", + Name: "testuser@", }, Hostinfo: &tailcfg.Hostinfo{}, }, diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index 3bc65e1c..2bbb793a 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/netip" "strconv" "strings" @@ -50,24 +51,29 @@ func theInternet() *netipx.IPSet { return theInternetSet } -type Asterix string +type Asterix int func (a Asterix) Validate() error { - if a == "*" { - return nil - } - return fmt.Errorf(`Asterix can only be "*", got: %s`, a) -} - -func (a *Asterix) String() string { - return string(*a) -} - -func (a *Asterix) UnmarshalJSON(b []byte) error { - *a = "*" return nil } +func (a Asterix) String() string { + return "*" +} + +func (a Asterix) UnmarshalJSON(b []byte) error { + return nil +} + +func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + ips.AddPrefix(tsaddr.AllIPv4()) + ips.AddPrefix(tsaddr.AllIPv6()) + + return ips.IPSet() +} + // Username is a string that represents a username, it must contain an @. type Username string @@ -120,8 +126,8 @@ func (g Group) Validate() error { return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } -func (g Group) UnmarshalJSON(b []byte) error { - g = Group(strings.Trim(string(b), `"`)) +func (g *Group) UnmarshalJSON(b []byte) error { + *g = Group(strings.Trim(string(b), `"`)) if err := g.Validate(); err != nil { return err } @@ -157,8 +163,8 @@ func (t Tag) Validate() error { return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } -func (t Tag) UnmarshalJSON(b []byte) error { - t = Tag(strings.Trim(string(b), `"`)) +func (t *Tag) UnmarshalJSON(b []byte) error { + *t = Tag(strings.Trim(string(b), `"`)) if err := t.Validate(); err != nil { return err } @@ -184,8 +190,8 @@ func (h Host) Validate() error { return nil } -func (h Host) UnmarshalJSON(b []byte) error { - h = Host(strings.Trim(string(b), `"`)) +func (h *Host) UnmarshalJSON(b []byte) error { + *h = Host(strings.Trim(string(b), `"`)) if err := h.Validate(); err != nil { return err } @@ -195,7 +201,15 @@ func (h Host) UnmarshalJSON(b []byte) error { func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder - ips.AddPrefix(netip.Prefix(p.Hosts[h])) + pref, ok := p.Hosts[h] + if !ok { + return nil, fmt.Errorf("unable to resolve host: %q", h) + } + err := pref.Validate() + if err != nil { + return nil, err + } + ips.AddPrefix(netip.Prefix(pref)) return ips.IPSet() } @@ -275,8 +289,8 @@ func (ag AutoGroup) Validate() error { return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) } -func (ag AutoGroup) UnmarshalJSON(b []byte) error { - ag = AutoGroup(strings.Trim(string(b), `"`)) +func (ag *AutoGroup) UnmarshalJSON(b []byte) error { + *ag = AutoGroup(strings.Trim(string(b), `"`)) if err := ag.Validate(); err != nil { return err } @@ -330,6 +344,9 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } ve.Alias = parseAlias(vs) + if ve.Alias == nil { + return fmt.Errorf("could not determine the type of %q", vs) + } if err := ve.Alias.Validate(); err != nil { return err } @@ -353,17 +370,22 @@ func parseAlias(vs string) Alias { switch { case vs == "*": - return ptr.To(Asterix("*")) + return Asterix(0) case strings.Contains(vs, "@"): return ptr.To(Username(vs)) case strings.HasPrefix(vs, "group:"): - ptr.To(Group(vs)) + return ptr.To(Group(vs)) case strings.HasPrefix(vs, "tag:"): - return Tag(vs) + return ptr.To(Tag(vs)) case strings.HasPrefix(vs, "autogroup:"): - return AutoGroup(vs) + return ptr.To(AutoGroup(vs)) } - return Host(vs) + + if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") { + return ptr.To(Host(vs)) + } + + return nil } // AliasEnc is used to deserialize a Alias. @@ -379,10 +401,13 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { switch val := v.(type) { case string: ve.Alias = parseAlias(val) - ve.Alias = parseAlias(val) + if ve.Alias == nil { + return fmt.Errorf("could not determine the type of %q", val) + } if err := ve.Alias.Validate(); err != nil { return err } + log.Printf("val: %q as type: %T", val, ve.Alias) default: return fmt.Errorf("type %T not supported", val) } @@ -440,10 +465,9 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { switch { case strings.Contains(val, "@"): - u := Username(val) - ve.Owner = &u + ve.Owner = ptr.To(Username(val)) case strings.HasPrefix(val, "group:"): - ve.Owner = Group(val) + ve.Owner = ptr.To(Group(val)) } default: return fmt.Errorf("type %T not supported", val) diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index 26423a61..018f9951 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -57,7 +57,7 @@ func TestUnmarshalPolicy(t *testing.T) { "tagOwners": { "tag:user": ["testuser@headscale.net"], "tag:group": ["group:other"], - "tag:userandgroup": ["testuser@headscale.net" ,"group:other"], + "tag:userandgroup": ["testuser@headscale.net", "group:other"], }, "hosts": { @@ -132,9 +132,9 @@ func TestUnmarshalPolicy(t *testing.T) { Group("group:other"): []Username{Username("otheruser@headscale.net")}, }, TagOwners: TagOwners{ - Tag("tag:user"): Owners{ptr.To(Username("testuser@headscale.net"))}, - Tag("tag:group"): Owners{Group("group:other")}, - Tag("tag:userandgroup"): Owners{ptr.To(Username("testuser@headscale.net")), Group("group:other")}, + Tag("tag:user"): Owners{up("testuser@headscale.net")}, + Tag("tag:group"): Owners{gp("group:other")}, + Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")}, }, Hosts: Hosts{ "host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")), @@ -152,14 +152,14 @@ func TestUnmarshalPolicy(t *testing.T) { // any approved subnets and autogroup:shared. // It does not allow traffic originating from // non-tailscale devices (unless it is an approved route). - Host("*"), + hp("*"), }, Destinations: []AliasWithPorts{ { // TODO(kradalby): Should this be host? // It is: // Includes any destination (no restrictions). - Alias: Host("*"), + Alias: hp("*"), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -181,11 +181,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Group("group:example"), + gp("group:example"), }, Destinations: []AliasWithPorts{ { - Alias: Group("group:other"), + Alias: gp("group:other"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -194,11 +194,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Prefix(netip.MustParsePrefix("100.101.102.103/32"))), + pp("100.101.102.103/32"), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Prefix(netip.MustParsePrefix("100.101.102.104/32"))), + Alias: pp("100.101.102.104/32"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -207,11 +207,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "udp", Sources: Aliases{ - ptr.To(Prefix(netip.MustParsePrefix("10.0.0.0/8"))), + pp("10.0.0.0/8"), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Prefix(netip.MustParsePrefix("172.16.0.0/16"))), + Alias: pp("172.16.0.0/16"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, }, }, @@ -220,11 +220,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Host("subnet-1"), + hp("subnet-1"), }, Destinations: []AliasWithPorts{ { - Alias: Host("host-1"), + Alias: hp("host-1"), Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, }, }, @@ -233,11 +233,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Tag("tag:group"), + tp("tag:group"), }, Destinations: []AliasWithPorts{ { - Alias: Tag("tag:user"), + Alias: tp("tag:user"), Ports: []tailcfg.PortRange{ tailcfg.PortRange{First: 80, Last: 80}, tailcfg.PortRange{First: 443, Last: 443}, @@ -249,11 +249,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - Tag("tag:group"), + tp("tag:group"), }, Destinations: []AliasWithPorts{ { - Alias: AutoGroup("autogroup:internet"), + Alias: agp("autogroup:internet"), Ports: []tailcfg.PortRange{ tailcfg.PortRange{First: 80, Last: 80}, }, @@ -367,6 +367,11 @@ func TestUnmarshalPolicy(t *testing.T) { } } +func gp(s string) *Group { return ptr.To(Group(s)) } +func up(s string) *Username { return ptr.To(Username(s)) } +func hp(s string) *Host { return ptr.To(Host(s)) } +func tp(s string) *Tag { return ptr.To(Tag(s)) } +func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) } func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) } @@ -392,7 +397,7 @@ func TestResolvePolicy(t *testing.T) { "testhost": p("100.100.101.102/32"), }, }, - toResolve: Host("testhost"), + toResolve: hp("testhost"), want: []netip.Prefix{mp("100.100.101.102/32")}, }, { @@ -491,7 +496,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "tag", - toResolve: Tag("tag:test"), + toResolve: tp("tag:test"), nodes: types.Nodes{ // Not matching other user { From a7b2468a420aaa4f00ce77317831cdb49f315f14 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 01:23:44 +0200 Subject: [PATCH 4/6] use json in TestReduceFilterRules test This is to allow for the tests to be ran with the new upcoming parser to ensure we get the same input. Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls_test.go | 493 ++++++++++++++++++++-------------- 1 file changed, 290 insertions(+), 203 deletions(-) diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..b3aa366e 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1838,20 +1838,27 @@ func TestReduceFilterRules(t *testing.T) { name string node *types.Node peers types.Nodes - pol ACLPolicy + pol string want []tailcfg.FilterRule }{ { name: "host1-can-reach-host2-no-rules", - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"100.64.0.1"}, - Destinations: []string{"100.64.0.2:*"}, - }, - }, - }, + pol: ` +{ + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "100.64.0.1" + ], + "dst": [ + "100.64.0.2:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), @@ -1868,23 +1875,37 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1604-subnet-routers-are-preserved", - pol: ACLPolicy{ - Groups: Groups{ - "group:admins": {"user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:admins"}, - Destinations: []string{"group:admins:*"}, - }, - { - Action: "accept", - Sources: []string{"group:admins"}, - Destinations: []string{"10.33.0.0/16:*"}, - }, - }, - }, + pol: ` +{ + "groups": { + "group:admins": [ + "user1" + ] + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "group:admins:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "10.33.0.0/16:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), @@ -1939,31 +1960,42 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1786-reducing-breaks-exit-nodes-the-client", - pol: ACLPolicy{ - Hosts: Hosts{ - // Exit node - "internal": netip.MustParsePrefix("100.64.0.100/32"), - }, - Groups: Groups{ - "group:team": {"user3", "user2", "user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "internal:*", - }, - }, - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "autogroup:internet:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:team": [ + "user3", + "user2", + "user1" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), @@ -1989,31 +2021,42 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1786-reducing-breaks-exit-nodes-the-exit", - pol: ACLPolicy{ - Hosts: Hosts{ - // Exit node - "internal": netip.MustParsePrefix("100.64.0.100/32"), - }, - Groups: Groups{ - "group:team": {"user3", "user2", "user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "internal:*", - }, - }, - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "autogroup:internet:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:team": [ + "user3", + "user2", + "user1" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), @@ -2056,60 +2099,71 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", - pol: ACLPolicy{ - Hosts: Hosts{ - // Exit node - "internal": netip.MustParsePrefix("100.64.0.100/32"), - }, - Groups: Groups{ - "group:team": {"user3", "user2", "user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "internal:*", - }, - }, - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "0.0.0.0/5:*", - "8.0.0.0/7:*", - "11.0.0.0/8:*", - "12.0.0.0/6:*", - "16.0.0.0/4:*", - "32.0.0.0/3:*", - "64.0.0.0/2:*", - "128.0.0.0/3:*", - "160.0.0.0/5:*", - "168.0.0.0/6:*", - "172.0.0.0/12:*", - "172.32.0.0/11:*", - "172.64.0.0/10:*", - "172.128.0.0/9:*", - "173.0.0.0/8:*", - "174.0.0.0/7:*", - "176.0.0.0/4:*", - "192.0.0.0/9:*", - "192.128.0.0/11:*", - "192.160.0.0/13:*", - "192.169.0.0/16:*", - "192.170.0.0/15:*", - "192.172.0.0/14:*", - "192.176.0.0/12:*", - "192.192.0.0/10:*", - "193.0.0.0/8:*", - "194.0.0.0/7:*", - "196.0.0.0/6:*", - "200.0.0.0/5:*", - "208.0.0.0/4:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:team": [ + "user3", + "user2", + "user1" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "0.0.0.0/5:*", + "8.0.0.0/7:*", + "11.0.0.0/8:*", + "12.0.0.0/6:*", + "16.0.0.0/4:*", + "32.0.0.0/3:*", + "64.0.0.0/2:*", + "128.0.0.0/3:*", + "160.0.0.0/5:*", + "168.0.0.0/6:*", + "172.0.0.0/12:*", + "172.32.0.0/11:*", + "172.64.0.0/10:*", + "172.128.0.0/9:*", + "173.0.0.0/8:*", + "174.0.0.0/7:*", + "176.0.0.0/4:*", + "192.0.0.0/9:*", + "192.128.0.0/11:*", + "192.160.0.0/13:*", + "192.169.0.0/16:*", + "192.170.0.0/15:*", + "192.172.0.0/14:*", + "192.176.0.0/12:*", + "192.192.0.0/10:*", + "193.0.0.0/8:*", + "194.0.0.0/7:*", + "196.0.0.0/6:*", + "200.0.0.0/5:*", + "208.0.0.0/4:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), @@ -2186,32 +2240,43 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1786-reducing-breaks-exit-nodes-app-connector-like", - pol: ACLPolicy{ - Hosts: Hosts{ - // Exit node - "internal": netip.MustParsePrefix("100.64.0.100/32"), - }, - Groups: Groups{ - "group:team": {"user3", "user2", "user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "internal:*", - }, - }, - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "8.0.0.0/8:*", - "16.0.0.0/8:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:team": [ + "user3", + "user2", + "user1" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/8:*", + "16.0.0.0/8:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), @@ -2263,32 +2328,43 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1786-reducing-breaks-exit-nodes-app-connector-like2", - pol: ACLPolicy{ - Hosts: Hosts{ - // Exit node - "internal": netip.MustParsePrefix("100.64.0.100/32"), - }, - Groups: Groups{ - "group:team": {"user3", "user2", "user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "internal:*", - }, - }, - { - Action: "accept", - Sources: []string{"group:team"}, - Destinations: []string{ - "8.0.0.0/16:*", - "16.0.0.0/16:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:team": [ + "user3", + "user2", + "user1" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/16:*", + "16.0.0.0/16:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), @@ -2340,25 +2416,32 @@ func TestReduceFilterRules(t *testing.T) { }, { name: "1817-reduce-breaks-32-mask", - pol: ACLPolicy{ - Hosts: Hosts{ - "vlan1": netip.MustParsePrefix("172.16.0.0/24"), - "dns1": netip.MustParsePrefix("172.16.0.21/32"), - }, - Groups: Groups{ - "group:access": {"user1"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:access"}, - Destinations: []string{ - "tag:access-servers:*", - "dns1:*", - }, - }, - }, - }, + pol: ` +{ + "groups": { + "group:access": [ + "user1" + ] + }, + "hosts": { + "dns1": "172.16.0.21/32", + "vlan1": "172.16.0.0/24" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:access" + ], + "dst": [ + "tag:access-servers:*", + "dns1:*" + ] + } + ], +} +`, node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), @@ -2399,7 +2482,11 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := tt.pol.CompileFilterRules( + pol, err := LoadACLPolicyFromBytes([]byte(tt.pol)) + if err != nil { + t.Fatalf("parsing policy: %s", err) + } + got, _ := pol.CompileFilterRules( append(tt.peers, tt.node), ) From 53cbdfc277dee48eaa05c2f343ce171e89ddabe5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 01:37:05 +0200 Subject: [PATCH 5/6] copy reduce test filter test to compare v1 vs v2 Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls_test.go | 14 +- hscontrol/policyv2/filter_test.go | 738 ++++++++++++++++++++++++++++++ hscontrol/policyv2/types.go | 2 - 3 files changed, 745 insertions(+), 9 deletions(-) diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index b3aa366e..af035f0f 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1764,9 +1764,9 @@ var tsExitNodeDest = []tailcfg.NetPortRange{ }, } -// hsExitNodeDest is the list of destination IP ranges that are allowed when +// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when // we use headscale "autogroup:internet". -var hsExitNodeDest = []tailcfg.NetPortRange{ +var hsExitNodeDestForTest = []tailcfg.NetPortRange{ {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, @@ -1823,13 +1823,13 @@ func TestTheInternet(t *testing.T) { internetPrefs := internetSet.Prefixes() for i := range internetPrefs { - if internetPrefs[i].String() != hsExitNodeDest[i].IP { - t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDest[i].IP) + if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP { + t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDestForTest[i].IP) } } - if len(internetPrefs) != len(hsExitNodeDest) { - t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDest)) + if len(internetPrefs) != len(hsExitNodeDestForTest) { + t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDestForTest)) } } @@ -2093,7 +2093,7 @@ func TestReduceFilterRules(t *testing.T) { }, { SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: hsExitNodeDest, + DstPorts: hsExitNodeDestForTest, }, }, }, diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go index 6b76b168..7dfb6a7b 100644 --- a/hscontrol/policyv2/filter_test.go +++ b/hscontrol/policyv2/filter_test.go @@ -1,10 +1,14 @@ package policyv2 import ( + "net/netip" "testing" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -376,3 +380,737 @@ func TestParsing(t *testing.T) { }) } } + +// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when +// we use headscale "autogroup:internet". +var hsExitNodeDestForTest = []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "2000::/3", Ports: tailcfg.PortRangeAny}, +} + +func TestReduceFilterRules(t *testing.T) { + tests := []struct { + name string + node *types.Node + peers types.Nodes + pol string + want []tailcfg.FilterRule + }{ + { + name: "host1-can-reach-host2-no-rules", + pol: ` +{ + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "100.64.0.1" + ], + "dst": [ + "100.64.0.2:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + User: types.User{Name: "mickael"}, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), + User: types.User{Name: "mickael"}, + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1604-subnet-routers-are-preserved", + pol: ` +{ + "groups": { + "group:admins": [ + "user1@" + ] + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "group:admins:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "10.33.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.1/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::1/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.33.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-client", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user2@"}, + }, + // "internal" exit node + &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-exit", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user2@"}, + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: hsExitNodeDestForTest, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "0.0.0.0/5:*", + "8.0.0.0/7:*", + "11.0.0.0/8:*", + "12.0.0.0/6:*", + "16.0.0.0/4:*", + "32.0.0.0/3:*", + "64.0.0.0/2:*", + "128.0.0.0/3:*", + "160.0.0.0/5:*", + "168.0.0.0/6:*", + "172.0.0.0/12:*", + "172.32.0.0/11:*", + "172.64.0.0/10:*", + "172.128.0.0/9:*", + "173.0.0.0/8:*", + "174.0.0.0/7:*", + "176.0.0.0/4:*", + "192.0.0.0/9:*", + "192.128.0.0/11:*", + "192.160.0.0/13:*", + "192.169.0.0/16:*", + "192.170.0.0/15:*", + "192.172.0.0/14:*", + "192.176.0.0/12:*", + "192.192.0.0/10:*", + "193.0.0.0/8:*", + "194.0.0.0/7:*", + "196.0.0.0/6:*", + "200.0.0.0/5:*", + "208.0.0.0/4:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user2@"}, + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/8:*", + "16.0.0.0/8:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user2@"}, + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/8", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/8", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like2", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/16:*", + "16.0.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: types.User{Name: "user2@"}, + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, + { + name: "1817-reduce-breaks-32-mask", + pol: ` +{ + "groups": { + "group:access": [ + "user1@" + ] + }, + "hosts": { + "dns1": "172.16.0.21/32", + "vlan1": "172.16.0.0/24" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:access" + ], + "dst": [ + "tag:access-servers:*", + "dns1:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: types.User{Name: "user100@"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, + }, + ForcedTags: []string{"tag:access-servers"}, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1@"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "172.16.0.21/32", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + polV1, err := policy.LoadACLPolicyFromBytes([]byte(tt.pol)) + if err != nil { + t.Fatalf("parsing policy: %s", err) + } + filterV1, _ := polV1.CompileFilterRules( + append(tt.peers, tt.node), + ) + polV2, err := PolicyFromBytes([]byte(tt.pol)) + if err != nil { + t.Fatalf("parsing policy: %s", err) + } + filterV2, _ := polV2.CompileFilterRules( + append(tt.peers, tt.node), + ) + + if diff := cmp.Diff(filterV1, filterV2); diff != "" { + log.Trace().Interface("got", filterV2).Msg("result") + t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff) + } + + // TODO(kradalby): Move this from v1, or + // rewrite. + filterV2 = policy.ReduceFilterRules(tt.node, filterV2) + + if diff := cmp.Diff(tt.want, filterV2); diff != "" { + log.Trace().Interface("got", filterV2).Msg("result") + t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index 2bbb793a..cd767f43 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net/netip" "strconv" "strings" @@ -407,7 +406,6 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { if err := ve.Alias.Validate(); err != nil { return err } - log.Printf("val: %q as type: %T", val, ve.Alias) default: return fmt.Errorf("type %T not supported", val) } From 907449bd993a506f7002ee4eebd904fdf89e90eb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 25 Oct 2024 11:52:12 -0500 Subject: [PATCH 6/6] policy manager Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 4 +- hscontrol/policyv2/filter.go | 105 ++++++++++++++++++++++- hscontrol/policyv2/filter_test.go | 97 +++++++++++---------- hscontrol/policyv2/policy.go | 80 ++++++++++++++++++ hscontrol/policyv2/policy_test.go | 58 +++++++++++++ hscontrol/policyv2/types.go | 134 ++++++++++++++++++++++++++---- hscontrol/policyv2/types_test.go | 22 ++--- hscontrol/types/node.go | 10 +++ hscontrol/types/users.go | 15 ++++ 9 files changed, 455 insertions(+), 70 deletions(-) create mode 100644 hscontrol/policyv2/policy.go create mode 100644 hscontrol/policyv2/policy_test.go diff --git a/hscontrol/app.go b/hscontrol/app.go index 737e8098..7140492a 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -30,6 +30,7 @@ import ( "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policyv2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" zerolog "github.com/philip-bui/grpc-zerolog" @@ -88,7 +89,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + ACLPolicy *policy.ACLPolicy + PolicyManager *policyv2.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier diff --git a/hscontrol/policyv2/filter.go b/hscontrol/policyv2/filter.go index 27f8c8d8..0498e953 100644 --- a/hscontrol/policyv2/filter.go +++ b/hscontrol/policyv2/filter.go @@ -3,6 +3,7 @@ package policyv2 import ( "errors" "fmt" + "time" "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" @@ -16,6 +17,7 @@ var ( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) CompileFilterRules( + users types.Users, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules( return nil, ErrInvalidAction } - srcIPs, err := acl.Sources.Resolve(pol, nodes) + srcIPs, err := acl.Sources.Resolve(pol, users, nodes) if err != nil { return nil, fmt.Errorf("resolving source ips: %w", err) } @@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules( var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { - ips, err := dest.Alias.Resolve(pol, nodes) + ips, err := dest.Alias.Resolve(pol, users, nodes) if err != nil { return nil, err } @@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules( return rules, nil } +func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { + return tailcfg.SSHAction{ + Reject: !accept, + Accept: accept, + SessionDuration: duration, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + } +} + +func (pol *Policy) CompileSSHPolicy( + users types.Users, + node types.Node, + nodes types.Nodes, +) (*tailcfg.SSHPolicy, error) { + if pol == nil { + return nil, nil + } + + var rules []*tailcfg.SSHRule + + for index, rule := range pol.SSHs { + var dest netipx.IPSetBuilder + for _, src := range rule.Destinations { + ips, err := src.Resolve(pol, users, nodes) + if err != nil { + return nil, err + } + dest.AddSet(ips) + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + if !node.InIPSet(destSet) { + continue + } + + var action tailcfg.SSHAction + switch rule.Action { + case "accept": + action = sshAction(true, 0) + case "check": + action = sshAction(true, rule.CheckPeriod) + default: + return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) + } + + var principals []*tailcfg.SSHPrincipal + for _, src := range rule.Sources { + if isWildcard(rawSrc) { + principals = append(principals, &tailcfg.SSHPrincipal{ + Any: true, + }) + } else if isGroup(rawSrc) { + users, err := pol.expandUsersFromGroup(rawSrc) + if err != nil { + return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err) + } + + for _, user := range users { + principals = append(principals, &tailcfg.SSHPrincipal{ + UserLogin: user, + }) + } + } else { + expandedSrcs, err := pol.ExpandAlias( + peers, + rawSrc, + ) + if err != nil { + return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err) + } + for _, expandedSrc := range expandedSrcs.Prefixes() { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: expandedSrc.Addr().String(), + }) + } + } + } + + userMap := make(map[string]string, len(rule.Users)) + for _, user := range rule.Users { + userMap[user] = "=" + } + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + + return &tailcfg.SSHPolicy{ + Rules: rules, + }, nil +} + func ipSetToPrefixStringList(ips *netipx.IPSet) []string { var out []string diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go index 7dfb6a7b..d994b18c 100644 --- a/hscontrol/policyv2/filter_test.go +++ b/hscontrol/policyv2/filter_test.go @@ -8,6 +8,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -17,6 +18,9 @@ import ( // Move it here, run it against both old and new CompileFilterRules func TestParsing(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser@"}, + } tests := []struct { name string format string @@ -340,7 +344,7 @@ func TestParsing(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pol, err := PolicyFromBytes([]byte(tt.acl)) + pol, err := policyFromBytes([]byte(tt.acl)) if tt.wantErr && err == nil { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -355,18 +359,18 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: ap("100.100.100.100"), - }, - &types.Node{ - IPv4: ap("200.200.200.200"), - User: types.User{ - Name: "testuser@", + rules, err := pol.CompileFilterRules( + users, + types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: ap("200.200.200.200"), + User: users[0], + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{ } func TestReduceFilterRules(t *testing.T) { + users := types.Users{ + types.User{Model: gorm.Model{ID: 1}, Name: "mickael"}, + types.User{Model: gorm.Model{ID: 2}, Name: "user1@"}, + types.User{Model: gorm.Model{ID: 3}, Name: "user2@"}, + types.User{Model: gorm.Model{ID: 4}, Name: "user100@"}, + types.User{Model: gorm.Model{ID: 5}, Name: "user3@"}, + } + tests := []struct { name string node *types.Node @@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) { {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, + // This should not be included I believe, seems like + // this is a bug in the v1 code. + // {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, + // {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, + // {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, @@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) { filterV1, _ := polV1.CompileFilterRules( append(tt.peers, tt.node), ) - polV2, err := PolicyFromBytes([]byte(tt.pol)) + pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node)) if err != nil { t.Fatalf("parsing policy: %s", err) } - filterV2, _ := polV2.CompileFilterRules( - append(tt.peers, tt.node), - ) - if diff := cmp.Diff(filterV1, filterV2); diff != "" { - log.Trace().Interface("got", filterV2).Msg("result") + if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" { + log.Trace().Interface("got", pm.Filter()).Msg("result") t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff) } // TODO(kradalby): Move this from v1, or // rewrite. - filterV2 = policy.ReduceFilterRules(tt.node, filterV2) + filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter()) if diff := cmp.Diff(tt.want, filterV2); diff != "" { log.Trace().Interface("got", filterV2).Msg("result") diff --git a/hscontrol/policyv2/policy.go b/hscontrol/policyv2/policy.go new file mode 100644 index 00000000..99b30bc3 --- /dev/null +++ b/hscontrol/policyv2/policy.go @@ -0,0 +1,80 @@ +package policyv2 + +import ( + "fmt" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +type PolicyManager struct { + mu sync.Mutex + pol *Policy + users []types.User + nodes types.Nodes + + filter []tailcfg.FilterRule + + // TODO(kradalby): Implement SSH policy + sshPolicy *tailcfg.SSHPolicy +} + +// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. +// It returns an error if the policy file is invalid. +// The policy manager will update the filter rules based on the users and nodes. +func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { + policy, err := policyFromBytes(b) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManager{ + pol: policy, + users: users, + nodes: nodes, + } + + err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +// Filter returns the current filter rules for the entire tailnet. +func (pm *PolicyManager) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManager) updateLocked() error { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return fmt.Errorf("compiling filter rules: %w", err) + } + + pm.filter = filter + + return nil +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetUsers(users []types.User) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetNodes(nodes types.Nodes) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} diff --git a/hscontrol/policyv2/policy_test.go b/hscontrol/policyv2/policy_test.go new file mode 100644 index 00000000..b4496a6e --- /dev/null +++ b/hscontrol/policyv2/policy_test.go @@ -0,0 +1,58 @@ +package policyv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { + return &types.Node{ + ID: 0, + Hostname: name, + IPv4: ap(ipv4), + IPv6: ap(ipv6), + User: user, + UserID: user.ID, + Hostinfo: hostinfo, + } +} + +func TestPolicyManager(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, + } + + tests := []struct { + name string + pol string + nodes types.Nodes + wantFilter []tailcfg.FilterRule + }{ + { + name: "empty-policy", + pol: "{}", + nodes: types.Nodes{}, + wantFilter: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + require.NoError(t, err) + + filter := pm.Filter() + if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { + t.Errorf("Filter() mismatch (-want +got):\n%s", diff) + } + + // TODO(kradalby): Test SSH Policy + }) + } +} diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index cd767f43..5a4c6de8 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -8,6 +8,7 @@ import ( "net/netip" "strconv" "strings" + "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error { return nil } -func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder ips.AddPrefix(tsaddr.AllIPv4()) @@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool { return true } -func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (u Username) resolveUser(users types.Users) (*types.User, error) { + var potentialUsers types.Users + for _, user := range users { + if user.ProviderIdentifier == string(u) { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == string(u) { + potentialUsers = append(potentialUsers, user) + } + if user.Name == string(u) { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) > 1 { + return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers) + } else if len(potentialUsers) == 0 { + return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users) + } + + user := potentialUsers[0] + + return &user, nil +} + +func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder + user, err := u.resolveUser(users) + if err != nil { + return nil, err + } + for _, node := range nodes { if node.IsTagged() { continue } - if node.User.Username() == string(u) { + if node.User.ID == user.ID { node.AppendToIPSet(&ips) } } @@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool { return true } -func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, user := range p.Groups[g] { - uips, err := user.Resolve(nil, nodes) + uips, err := user.Resolve(nil, users, nodes) if err != nil { return nil, err } @@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes { @@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder pref, ok := p.Hosts[h] @@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { if err != nil { return nil, err } + + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + appendIfNodeHasIP(nodes, &ips, pref) ips.AddPrefix(netip.Prefix(pref)) return ips.IPSet() } +func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) { + if netip.Prefix(pref).IsSingleIP() { + addr := netip.Prefix(pref).Addr() + for _, node := range nodes { + if node.HasIP(addr) { + node.AppendToIPSet(ips) + } + } + } +} + type Prefix netip.Prefix func (p Prefix) Validate() error { @@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { return nil } -func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { +func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder + appendIfNodeHasIP(nodes, &ips, p) ips.AddPrefix(netip.Prefix(p)) return ips.IPSet() @@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { return nil } -func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { +func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) { switch ag { case AutoGroupInternet: return theInternet(), nil @@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { type Alias interface { Validate() error UnmarshalJSON([]byte) error - Resolve(*Policy, types.Nodes) (*netipx.IPSet, error) + Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { return nil } -func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, alias := range a { - aips, err := alias.Resolve(p, nodes) + aips, err := alias.Resolve(p, users, nodes) if err != nil { return nil, err } @@ -530,10 +579,67 @@ type Policy struct { TagOwners TagOwners `json:"tagOwners"` ACLs []ACL `json:"acls"` AutoApprovers AutoApprovers `json:"autoApprovers"` - // SSHs []SSH `json:"ssh"` + SSHs []SSH `json:"ssh"` } -func PolicyFromBytes(b []byte) (*Policy, error) { +// SSH controls who can ssh into which machines. +type SSH struct { + Action string `json:"action"` + Sources SSHSrcAliases `json:"src"` + Destinations SSHDstAliases `json:"dst"` + Users []SSHUser `json:"users"` + CheckPeriod time.Duration `json:"checkPeriod,omitempty"` +} + +// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. +// It can be a list of usernames, groups, tags or autogroups. +type SSHSrcAliases []Alias + +func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Group, *Tag, *AutoGroup: + (*a)[i] = alias.Alias + default: + return fmt.Errorf("type %T not supported", alias.Alias) + } + } + return nil +} + +// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule. +// It can be a list of usernames, tags or autogroups. +type SSHDstAliases []Alias + +func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Tag, *AutoGroup: + (*a)[i] = alias.Alias + default: + return fmt.Errorf("type %T not supported", alias.Alias) + } + } + return nil +} + +type SSHUser string + +func policyFromBytes(b []byte) (*Policy, error) { var policy Policy ast, err := hujson.Parse(b) if err != nil { diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index 018f9951..9060e657 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: ptr.To(Username("otheruser@headscale.net")), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: gp("group:other"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: pp("100.101.102.104/32"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: pp("172.16.0.0/16"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: hp("host-1"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 88}}, }, }, }, @@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) { { Alias: tp("tag:user"), Ports: []tailcfg.PortRange{ - tailcfg.PortRange{First: 80, Last: 80}, - tailcfg.PortRange{First: 443, Last: 443}, + {First: 80, Last: 80}, + {First: 443, Last: 443}, }, }, }, @@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) { { Alias: agp("autogroup:internet"), Ports: []tailcfg.PortRange{ - tailcfg.PortRange{First: 80, Last: 80}, + {First: 80, Last: 80}, }, }, }, @@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - policy, err := PolicyFromBytes([]byte(tt.input)) + policy, err := policyFromBytes([]byte(tt.input)) // TODO(kradalby): This error checking is broken, // but so is my brain, #longflight if err == nil { @@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes) + ips, err := tt.toResolve.Resolve(tt.pol, + types.Users{}, + tt.nodes) if err != nil { t.Fatalf("failed to resolve: %s", err) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 92aba945..bc195c14 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -135,6 +135,16 @@ func (node *Node) IPs() []netip.Addr { return ret } +// HasIP reports if a node has a given IP address. +func (node *Node) HasIP(i netip.Addr) bool { + for _, ip := range node.IPs() { + if ip.Compare(i) == 0 { + return true + } + } + return false +} + // IsTagged reports if a device is tagged // and therefore should not be treated as a // user owned device. diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f983d7f5..aa73fdf8 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,7 +2,9 @@ package types import ( "cmp" + "fmt" "strconv" + "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -13,6 +15,19 @@ import ( type UserID uint64 +type Users []User + +func (u Users) String() string { + var sb strings.Builder + sb.WriteString("[ ") + for _, user := range u { + fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) + } + sb.WriteString(" ]") + + return sb.String() +} + // User is the way Headscale implements the concept of users in Tailscale // // At the end of the day, users in Tailscale are some kind of 'bubbles' or users