diff --git a/hscontrol/app.go b/hscontrol/app.go index 5c85b064..ecffc06f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -30,6 +30,7 @@ import ( "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policyv2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" zerolog "github.com/philip-bui/grpc-zerolog" @@ -88,7 +89,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + ACLPolicy *policy.ACLPolicy + PolicyManager *policyv2.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier diff --git a/hscontrol/policyv2/filter.go b/hscontrol/policyv2/filter.go index 27f8c8d8..0498e953 100644 --- a/hscontrol/policyv2/filter.go +++ b/hscontrol/policyv2/filter.go @@ -3,6 +3,7 @@ package policyv2 import ( "errors" "fmt" + "time" "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" @@ -16,6 +17,7 @@ var ( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) CompileFilterRules( + users types.Users, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules( return nil, ErrInvalidAction } - srcIPs, err := acl.Sources.Resolve(pol, nodes) + srcIPs, err := acl.Sources.Resolve(pol, users, nodes) if err != nil { return nil, fmt.Errorf("resolving source ips: %w", err) } @@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules( var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { - ips, err := dest.Alias.Resolve(pol, nodes) + ips, err := dest.Alias.Resolve(pol, users, nodes) if err != nil { return nil, err } @@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules( return rules, nil } +func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { + return tailcfg.SSHAction{ + Reject: !accept, + Accept: accept, + SessionDuration: duration, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + } +} + +func (pol *Policy) CompileSSHPolicy( + users types.Users, + node types.Node, + nodes types.Nodes, +) (*tailcfg.SSHPolicy, error) { + if pol == nil { + return nil, nil + } + + var rules []*tailcfg.SSHRule + + for index, rule := range pol.SSHs { + var dest netipx.IPSetBuilder + for _, src := range rule.Destinations { + ips, err := src.Resolve(pol, users, nodes) + if err != nil { + return nil, err + } + dest.AddSet(ips) + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + if !node.InIPSet(destSet) { + continue + } + + var action tailcfg.SSHAction + switch rule.Action { + case "accept": + action = sshAction(true, 0) + case "check": + action = sshAction(true, rule.CheckPeriod) + default: + return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) + } + + var principals []*tailcfg.SSHPrincipal + for _, src := range rule.Sources { + if isWildcard(rawSrc) { + principals = append(principals, &tailcfg.SSHPrincipal{ + Any: true, + }) + } else if isGroup(rawSrc) { + users, err := pol.expandUsersFromGroup(rawSrc) + if err != nil { + return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err) + } + + for _, user := range users { + principals = append(principals, &tailcfg.SSHPrincipal{ + UserLogin: user, + }) + } + } else { + expandedSrcs, err := pol.ExpandAlias( + peers, + rawSrc, + ) + if err != nil { + return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err) + } + for _, expandedSrc := range expandedSrcs.Prefixes() { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: expandedSrc.Addr().String(), + }) + } + } + } + + userMap := make(map[string]string, len(rule.Users)) + for _, user := range rule.Users { + userMap[user] = "=" + } + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + + return &tailcfg.SSHPolicy{ + Rules: rules, + }, nil +} + func ipSetToPrefixStringList(ips *netipx.IPSet) []string { var out []string diff --git a/hscontrol/policyv2/filter_test.go b/hscontrol/policyv2/filter_test.go index 7dfb6a7b..d994b18c 100644 --- a/hscontrol/policyv2/filter_test.go +++ b/hscontrol/policyv2/filter_test.go @@ -8,6 +8,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -17,6 +18,9 @@ import ( // Move it here, run it against both old and new CompileFilterRules func TestParsing(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser@"}, + } tests := []struct { name string format string @@ -340,7 +344,7 @@ func TestParsing(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pol, err := PolicyFromBytes([]byte(tt.acl)) + pol, err := policyFromBytes([]byte(tt.acl)) if tt.wantErr && err == nil { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -355,18 +359,18 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: ap("100.100.100.100"), - }, - &types.Node{ - IPv4: ap("200.200.200.200"), - User: types.User{ - Name: "testuser@", + rules, err := pol.CompileFilterRules( + users, + types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: ap("200.200.200.200"), + User: users[0], + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{ } func TestReduceFilterRules(t *testing.T) { + users := types.Users{ + types.User{Model: gorm.Model{ID: 1}, Name: "mickael"}, + types.User{Model: gorm.Model{ID: 2}, Name: "user1@"}, + types.User{Model: gorm.Model{ID: 3}, Name: "user2@"}, + types.User{Model: gorm.Model{ID: 4}, Name: "user100@"}, + types.User{Model: gorm.Model{ID: 5}, Name: "user3@"}, + } + tests := []struct { name string node *types.Node @@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) { {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, + // This should not be included I believe, seems like + // this is a bug in the v1 code. + // {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, + // {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, + // {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, @@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2@"}, + User: users[2], }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100@"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1@"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) { filterV1, _ := polV1.CompileFilterRules( append(tt.peers, tt.node), ) - polV2, err := PolicyFromBytes([]byte(tt.pol)) + pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node)) if err != nil { t.Fatalf("parsing policy: %s", err) } - filterV2, _ := polV2.CompileFilterRules( - append(tt.peers, tt.node), - ) - if diff := cmp.Diff(filterV1, filterV2); diff != "" { - log.Trace().Interface("got", filterV2).Msg("result") + if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" { + log.Trace().Interface("got", pm.Filter()).Msg("result") t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff) } // TODO(kradalby): Move this from v1, or // rewrite. - filterV2 = policy.ReduceFilterRules(tt.node, filterV2) + filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter()) if diff := cmp.Diff(tt.want, filterV2); diff != "" { log.Trace().Interface("got", filterV2).Msg("result") diff --git a/hscontrol/policyv2/policy.go b/hscontrol/policyv2/policy.go new file mode 100644 index 00000000..99b30bc3 --- /dev/null +++ b/hscontrol/policyv2/policy.go @@ -0,0 +1,80 @@ +package policyv2 + +import ( + "fmt" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +type PolicyManager struct { + mu sync.Mutex + pol *Policy + users []types.User + nodes types.Nodes + + filter []tailcfg.FilterRule + + // TODO(kradalby): Implement SSH policy + sshPolicy *tailcfg.SSHPolicy +} + +// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. +// It returns an error if the policy file is invalid. +// The policy manager will update the filter rules based on the users and nodes. +func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { + policy, err := policyFromBytes(b) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManager{ + pol: policy, + users: users, + nodes: nodes, + } + + err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +// Filter returns the current filter rules for the entire tailnet. +func (pm *PolicyManager) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManager) updateLocked() error { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return fmt.Errorf("compiling filter rules: %w", err) + } + + pm.filter = filter + + return nil +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetUsers(users []types.User) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetNodes(nodes types.Nodes) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} diff --git a/hscontrol/policyv2/policy_test.go b/hscontrol/policyv2/policy_test.go new file mode 100644 index 00000000..b4496a6e --- /dev/null +++ b/hscontrol/policyv2/policy_test.go @@ -0,0 +1,58 @@ +package policyv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { + return &types.Node{ + ID: 0, + Hostname: name, + IPv4: ap(ipv4), + IPv6: ap(ipv6), + User: user, + UserID: user.ID, + Hostinfo: hostinfo, + } +} + +func TestPolicyManager(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, + } + + tests := []struct { + name string + pol string + nodes types.Nodes + wantFilter []tailcfg.FilterRule + }{ + { + name: "empty-policy", + pol: "{}", + nodes: types.Nodes{}, + wantFilter: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + require.NoError(t, err) + + filter := pm.Filter() + if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { + t.Errorf("Filter() mismatch (-want +got):\n%s", diff) + } + + // TODO(kradalby): Test SSH Policy + }) + } +} diff --git a/hscontrol/policyv2/types.go b/hscontrol/policyv2/types.go index cd767f43..5a4c6de8 100644 --- a/hscontrol/policyv2/types.go +++ b/hscontrol/policyv2/types.go @@ -8,6 +8,7 @@ import ( "net/netip" "strconv" "strings" + "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error { return nil } -func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder ips.AddPrefix(tsaddr.AllIPv4()) @@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool { return true } -func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (u Username) resolveUser(users types.Users) (*types.User, error) { + var potentialUsers types.Users + for _, user := range users { + if user.ProviderIdentifier == string(u) { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == string(u) { + potentialUsers = append(potentialUsers, user) + } + if user.Name == string(u) { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) > 1 { + return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers) + } else if len(potentialUsers) == 0 { + return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users) + } + + user := potentialUsers[0] + + return &user, nil +} + +func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder + user, err := u.resolveUser(users) + if err != nil { + return nil, err + } + for _, node := range nodes { if node.IsTagged() { continue } - if node.User.Username() == string(u) { + if node.User.ID == user.ID { node.AppendToIPSet(&ips) } } @@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool { return true } -func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, user := range p.Groups[g] { - uips, err := user.Resolve(nil, nodes) + uips, err := user.Resolve(nil, users, nodes) if err != nil { return nil, err } @@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes { @@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder pref, ok := p.Hosts[h] @@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { if err != nil { return nil, err } + + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + appendIfNodeHasIP(nodes, &ips, pref) ips.AddPrefix(netip.Prefix(pref)) return ips.IPSet() } +func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) { + if netip.Prefix(pref).IsSingleIP() { + addr := netip.Prefix(pref).Addr() + for _, node := range nodes { + if node.HasIP(addr) { + node.AppendToIPSet(ips) + } + } + } +} + type Prefix netip.Prefix func (p Prefix) Validate() error { @@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { return nil } -func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { +func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder + appendIfNodeHasIP(nodes, &ips, p) ips.AddPrefix(netip.Prefix(p)) return ips.IPSet() @@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { return nil } -func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { +func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) { switch ag { case AutoGroupInternet: return theInternet(), nil @@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { type Alias interface { Validate() error UnmarshalJSON([]byte) error - Resolve(*Policy, types.Nodes) (*netipx.IPSet, error) + Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { return nil } -func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { +func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, alias := range a { - aips, err := alias.Resolve(p, nodes) + aips, err := alias.Resolve(p, users, nodes) if err != nil { return nil, err } @@ -530,10 +579,67 @@ type Policy struct { TagOwners TagOwners `json:"tagOwners"` ACLs []ACL `json:"acls"` AutoApprovers AutoApprovers `json:"autoApprovers"` - // SSHs []SSH `json:"ssh"` + SSHs []SSH `json:"ssh"` } -func PolicyFromBytes(b []byte) (*Policy, error) { +// SSH controls who can ssh into which machines. +type SSH struct { + Action string `json:"action"` + Sources SSHSrcAliases `json:"src"` + Destinations SSHDstAliases `json:"dst"` + Users []SSHUser `json:"users"` + CheckPeriod time.Duration `json:"checkPeriod,omitempty"` +} + +// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. +// It can be a list of usernames, groups, tags or autogroups. +type SSHSrcAliases []Alias + +func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Group, *Tag, *AutoGroup: + (*a)[i] = alias.Alias + default: + return fmt.Errorf("type %T not supported", alias.Alias) + } + } + return nil +} + +// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule. +// It can be a list of usernames, tags or autogroups. +type SSHDstAliases []Alias + +func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Tag, *AutoGroup: + (*a)[i] = alias.Alias + default: + return fmt.Errorf("type %T not supported", alias.Alias) + } + } + return nil +} + +type SSHUser string + +func policyFromBytes(b []byte) (*Policy, error) { var policy Policy ast, err := hujson.Parse(b) if err != nil { diff --git a/hscontrol/policyv2/types_test.go b/hscontrol/policyv2/types_test.go index 018f9951..9060e657 100644 --- a/hscontrol/policyv2/types_test.go +++ b/hscontrol/policyv2/types_test.go @@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: ptr.To(Username("otheruser@headscale.net")), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: gp("group:other"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: pp("100.101.102.104/32"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: pp("172.16.0.0/16"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, }, @@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) { Destinations: []AliasWithPorts{ { Alias: hp("host-1"), - Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, + Ports: []tailcfg.PortRange{{First: 80, Last: 88}}, }, }, }, @@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) { { Alias: tp("tag:user"), Ports: []tailcfg.PortRange{ - tailcfg.PortRange{First: 80, Last: 80}, - tailcfg.PortRange{First: 443, Last: 443}, + {First: 80, Last: 80}, + {First: 443, Last: 443}, }, }, }, @@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) { { Alias: agp("autogroup:internet"), Ports: []tailcfg.PortRange{ - tailcfg.PortRange{First: 80, Last: 80}, + {First: 80, Last: 80}, }, }, }, @@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - policy, err := PolicyFromBytes([]byte(tt.input)) + policy, err := policyFromBytes([]byte(tt.input)) // TODO(kradalby): This error checking is broken, // but so is my brain, #longflight if err == nil { @@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes) + ips, err := tt.toResolve.Resolve(tt.pol, + types.Users{}, + tt.nodes) if err != nil { t.Fatalf("failed to resolve: %s", err) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 1f353e90..f5c29ed6 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -130,6 +130,16 @@ func (node *Node) IPs() []netip.Addr { return ret } +// HasIP reports if a node has a given IP address. +func (node *Node) HasIP(i netip.Addr) bool { + for _, ip := range node.IPs() { + if ip.Compare(i) == 0 { + return true + } + } + return false +} + // IsTagged reports if a device is tagged // and therefore should not be treated as a // user owned device. diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 35839f8e..6bb929f4 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,7 +2,9 @@ package types import ( "cmp" + "fmt" "strconv" + "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -13,6 +15,19 @@ import ( type UserID uint64 +type Users []User + +func (u Users) String() string { + var sb strings.Builder + sb.WriteString("[ ") + for _, user := range u { + fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) + } + sb.WriteString(" ]") + + return sb.String() +} + // User is the way Headscale implements the concept of users in Tailscale // // At the end of the day, users in Tailscale are some kind of 'bubbles' or users