Enhance route command with ptables and multiple routes

This commit rewrites the `routes list` command to use ptables to present
a slightly nicer list, including a new field if the route is enabled or
not (which is quite useful).

In addition, it reworks the enable command to support enabling multiple
routes (not only one route as per removed TODO). This allows users to
actually take advantage of exit-nodes and subnet relays.
This commit is contained in:
Kristoffer Dalby 2021-08-21 14:49:46 +01:00
parent 47b61c0cea
commit c883e79884
No known key found for this signature in database
GPG key ID: 09F62DC067465735
3 changed files with 202 additions and 36 deletions

View file

@ -5,6 +5,7 @@ import (
"log"
"strings"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
)
@ -44,19 +45,25 @@ var listRoutesCmd = &cobra.Command{
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
routes, err := h.GetNodeRoutes(n, args[0])
if strings.HasPrefix(o, "json") {
JsonOutput(routes, err, o)
return
}
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
fmt.Println(routes)
if strings.HasPrefix(o, "json") {
// TODO: Add enable/disabled information to this interface
JsonOutput(availableRoutes, err, o)
return
}
d := h.RoutesToPtables(n, args[0], *availableRoutes)
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
log.Fatal(err)
}
},
}
@ -80,9 +87,10 @@ var enableRouteCmd = &cobra.Command{
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
route, err := h.EnableNodeRoute(n, args[0], args[1])
err = h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(route, err, o)
JsonOutput(args[1], err, o)
return
}
@ -90,6 +98,6 @@ var enableRouteCmd = &cobra.Command{
fmt.Println(err)
return
}
fmt.Printf("Enabled route %s\n", route)
fmt.Printf("Enabled route %s\n", args[1])
},
}

129
routes.go
View file

@ -2,55 +2,140 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"github.com/pterm/pterm"
"gorm.io/datatypes"
"inet.af/netaddr"
)
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
hostInfo, err := m.GetHostInfo()
if err != nil {
return nil, err
}
return &hi.RoutableIPs, nil
return &hostInfo.RoutableIPs, nil
}
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
data, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
routes[index] = route
}
return routes, nil
}
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
if err != nil {
return nil, err
return err
}
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
return err
}
for _, rIP := range hi.RoutableIPs {
if rIP == route {
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return nil, err
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
available := false
for _, availableRoute := range *availableRoutes {
// If the route is available, and not yet enabled, add it to the new routing table
if route == availableRoute {
available = true
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
enabledRoutes = append(enabledRoutes, route)
}
return &rIP, nil
}
}
return nil, errors.New("could not find routable range")
if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
}
routes, err := json.Marshal(enabledRoutes)
if err != nil {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return err
}
return nil
}
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
for _, route := range availableRoutes {
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
}
return d
}

View file

@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
Name: "test_get_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
}
h.db.Save(&m)
r, err := h.GetNodeRoutes("test", "testmachine")
r, err := h.GetAdvertisedNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1)
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netaddr.ParseIPPrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
hostinfo, err := json.Marshal(hi)
c.Assert(err, check.IsNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
// Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
err = h.EnableNodeRoute("test", "testmachine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2)
}