From c883e798849edb3d3babe25c24dba9d832da6bf8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 21 Aug 2021 14:49:46 +0100 Subject: [PATCH] Enhance route command with ptables and multiple routes This commit rewrites the `routes list` command to use ptables to present a slightly nicer list, including a new field if the route is enabled or not (which is quite useful). In addition, it reworks the enable command to support enabling multiple routes (not only one route as per removed TODO). This allows users to actually take advantage of exit-nodes and subnet relays. --- cmd/headscale/cli/routes.go | 28 +++++--- routes.go | 129 ++++++++++++++++++++++++++++++------ routes_test.go | 81 ++++++++++++++++++++-- 3 files changed, 202 insertions(+), 36 deletions(-) diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 98b653f9..f58d4990 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -5,6 +5,7 @@ import ( "log" "strings" + "github.com/pterm/pterm" "github.com/spf13/cobra" ) @@ -44,19 +45,25 @@ var listRoutesCmd = &cobra.Command{ if err != nil { log.Fatalf("Error initializing: %s", err) } - routes, err := h.GetNodeRoutes(n, args[0]) - - if strings.HasPrefix(o, "json") { - JsonOutput(routes, err, o) - return - } + availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0]) if err != nil { fmt.Println(err) return } - fmt.Println(routes) + if strings.HasPrefix(o, "json") { + // TODO: Add enable/disabled information to this interface + JsonOutput(availableRoutes, err, o) + return + } + + d := h.RoutesToPtables(n, args[0], *availableRoutes) + + err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + if err != nil { + log.Fatal(err) + } }, } @@ -80,9 +87,10 @@ var enableRouteCmd = &cobra.Command{ if err != nil { log.Fatalf("Error initializing: %s", err) } - route, err := h.EnableNodeRoute(n, args[0], args[1]) + + err = h.EnableNodeRoute(n, args[0], args[1]) if strings.HasPrefix(o, "json") { - JsonOutput(route, err, o) + JsonOutput(args[1], err, o) return } @@ -90,6 +98,6 @@ var enableRouteCmd = &cobra.Command{ fmt.Println(err) return } - fmt.Printf("Enabled route %s\n", route) + fmt.Printf("Enabled route %s\n", args[1]) }, } diff --git a/routes.go b/routes.go index 202754b1..28d86837 100644 --- a/routes.go +++ b/routes.go @@ -2,55 +2,140 @@ package headscale import ( "encoding/json" - "errors" + "fmt" + "strconv" + "github.com/pterm/pterm" "gorm.io/datatypes" "inet.af/netaddr" ) -// GetNodeRoutes returns the subnet routes advertised by a node (identified by +// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by // namespace and node name) -func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { +func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { m, err := h.GetMachine(namespace, nodeName) if err != nil { return nil, err } - hi, err := m.GetHostInfo() + hostInfo, err := m.GetHostInfo() if err != nil { return nil, err } - return &hi.RoutableIPs, nil + return &hostInfo.RoutableIPs, nil +} + +// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by +// namespace and node name) +func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) { + m, err := h.GetMachine(namespace, nodeName) + if err != nil { + return nil, err + } + + data, err := m.EnabledRoutes.MarshalJSON() + if err != nil { + return nil, err + } + + routesStr := []string{} + err = json.Unmarshal(data, &routesStr) + if err != nil { + return nil, err + } + + routes := make([]netaddr.IPPrefix, len(routesStr)) + for index, routeStr := range routesStr { + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return nil, err + } + routes[index] = route + } + + return routes, nil +} + +func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool { + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return false + } + + enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) + if err != nil { + return false + } + + for _, enabledRoute := range enabledRoutes { + if route == enabledRoute { + return true + } + } + return false } // EnableNodeRoute enables a subnet route advertised by a node (identified by // namespace and node name) -func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) { +func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { m, err := h.GetMachine(namespace, nodeName) if err != nil { - return nil, err - } - hi, err := m.GetHostInfo() - if err != nil { - return nil, err + return err } + route, err := netaddr.ParseIPPrefix(routeStr) if err != nil { - return nil, err + return err } - for _, rIP := range hi.RoutableIPs { - if rIP == route { - routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest - m.EnabledRoutes = datatypes.JSON(routes) - h.db.Save(&m) + availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName) + if err != nil { + return err + } - err = h.RequestMapUpdates(m.NamespaceID) - if err != nil { - return nil, err + enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName) + if err != nil { + return err + } + + available := false + for _, availableRoute := range *availableRoutes { + // If the route is available, and not yet enabled, add it to the new routing table + if route == availableRoute { + available = true + if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) { + enabledRoutes = append(enabledRoutes, route) } - return &rIP, nil } } - return nil, errors.New("could not find routable range") + + if !available { + return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr) + } + + routes, err := json.Marshal(enabledRoutes) + if err != nil { + return err + } + + m.EnabledRoutes = datatypes.JSON(routes) + h.db.Save(&m) + + err = h.RequestMapUpdates(m.NamespaceID) + if err != nil { + return err + } + + return nil +} + +func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData { + d := pterm.TableData{{"Route", "Enabled"}} + + for _, route := range availableRoutes { + enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String()) + + d = append(d, []string{route.String(), strconv.FormatBool(enabled)}) + } + return d } diff --git a/routes_test.go b/routes_test.go index a05b7e16..33aaa9df 100644 --- a/routes_test.go +++ b/routes_test.go @@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", - Name: "testmachine", + Name: "test_get_route_machine", NamespaceID: n.ID, Registered: true, RegisterMethod: "authKey", @@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) { } h.db.Save(&m) - r, err := h.GetNodeRoutes("test", "testmachine") + r, err := h.GetAdvertisedNodeRoutes("test", "testmachine") c.Assert(err, check.IsNil) c.Assert(len(*r), check.Equals, 1) - _, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24") + err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24") c.Assert(err, check.NotNil) - _, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24") + err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24") + c.Assert(err, check.IsNil) +} + +func (s *Suite) TestGetEnableRoutes(c *check.C) { + n, err := h.CreateNamespace("test") c.Assert(err, check.IsNil) + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + route, err := netaddr.ParseIPPrefix( + "10.0.0.0/24", + ) + c.Assert(err, check.IsNil) + + route2, err := netaddr.ParseIPPrefix( + "150.0.10.0/25", + ) + c.Assert(err, check.IsNil) + + hi := tailcfg.Hostinfo{ + RoutableIPs: []netaddr.IPPrefix{route, route2}, + } + hostinfo, err := json.Marshal(hi) + c.Assert(err, check.IsNil) + + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "test_enable_route_machine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + HostInfo: datatypes.JSON(hostinfo), + } + h.db.Save(&m) + + availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(len(*availableRoutes), check.Equals, 2) + + enabledRoutes, err := h.GetEnabledNodeRoutes("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes), check.Equals, 0) + + err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24") + c.Assert(err, check.NotNil) + + err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24") + c.Assert(err, check.IsNil) + + enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes1), check.Equals, 1) + + // Adding it twice will just let it pass through + err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24") + c.Assert(err, check.IsNil) + + enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes2), check.Equals, 1) + + err = h.EnableNodeRoute("test", "testmachine", "150.0.10.0/25") + c.Assert(err, check.IsNil) + + enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes3), check.Equals, 2) }