From 787814ea89b677f28b77435f687b9e357bb1d9e9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 4 Nov 2021 22:11:38 +0000 Subject: [PATCH] Consolidate machine related lookups This commit moves the routes lookup functions to be subcommands of Machine, making them a lot simpler and more specific/composable. It also moves the register command from cli.go into machine, so we can clear out the extra file. Finally a toProto function has been added to convert between the machine database model and the proto/rpc model. --- cli.go | 43 --------- machine.go | 249 ++++++++++++++++++++++++++++++++++++++++++++++++++--- routes.go | 18 +--- 3 files changed, 239 insertions(+), 71 deletions(-) delete mode 100644 cli.go diff --git a/cli.go b/cli.go deleted file mode 100644 index 8610b334..00000000 --- a/cli.go +++ /dev/null @@ -1,43 +0,0 @@ -package headscale - -import ( - "errors" - - "gorm.io/gorm" - "tailscale.com/types/wgkey" -) - -// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey -func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { - ns, err := h.GetNamespace(namespace) - if err != nil { - return nil, err - } - mKey, err := wgkey.ParseHex(key) - if err != nil { - return nil, err - } - - m := Machine{} - if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, errors.New("Machine not found") - } - - h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered - - if m.isAlreadyRegistered() { - return nil, errors.New("Machine already registered") - } - - ip, err := h.getAvailableIP() - if err != nil { - return nil, err - } - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "cli" - h.db.Save(&m) - - return &m, nil -} diff --git a/machine.go b/machine.go index ccd30e3e..557ab5b6 100644 --- a/machine.go +++ b/machine.go @@ -2,6 +2,7 @@ package headscale import ( "encoding/json" + "errors" "fmt" "sort" "strconv" @@ -10,8 +11,11 @@ import ( "github.com/fatih/set" "github.com/rs/zerolog/log" + "google.golang.org/protobuf/types/known/timestamppb" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "gorm.io/datatypes" + "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" @@ -91,7 +95,7 @@ func (h *Headscale) updateMachineExpiry(m *Machine) { func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). - Str("func", "getDirectPeers"). + Caller(). Str("machine", m.Name). Msg("Finding direct peers") @@ -105,7 +109,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) log.Trace(). - Str("func", "getDirectmachines"). + Caller(). Str("machine", m.Name). Msgf("Found direct machines: %s", machines.String()) return machines, nil @@ -114,7 +118,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for func (h *Headscale) getShared(m *Machine) (Machines, error) { log.Trace(). - Str("func", "getShared"). + Caller(). Str("machine", m.Name). Msg("Finding shared peers") @@ -132,7 +136,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) log.Trace(). - Str("func", "getShared"). + Caller(). Str("machine", m.Name). Msgf("Found shared peers: %s", peers.String()) return peers, nil @@ -141,7 +145,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { // getSharedTo fetches the machines of the namespaces this machine is shared in func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { log.Trace(). - Str("func", "getSharedTo"). + Caller(). Str("machine", m.Name). Msg("Finding peers in namespaces this machine is shared with") @@ -157,13 +161,13 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { if err != nil { return Machines{}, err } - peers = append(peers, *namespaceMachines...) + peers = append(peers, namespaceMachines...) } sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) log.Trace(). - Str("func", "getSharedTo"). + Caller(). Str("machine", m.Name). Msgf("Found peers we are shared with: %s", peers.String()) return peers, nil @@ -173,7 +177,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { direct, err := h.getDirectPeers(m) if err != nil { log.Error(). - Str("func", "getPeers"). + Caller(). Err(err). Msg("Cannot fetch peers") return Machines{}, err @@ -182,7 +186,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { shared, err := h.getShared(m) if err != nil { log.Error(). - Str("func", "getShared"). + Caller(). Err(err). Msg("Cannot fetch peers") return Machines{}, err @@ -191,7 +195,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { sharedTo, err := h.getSharedTo(m) if err != nil { log.Error(). - Str("func", "sharedTo"). + Caller(). Err(err). Msg("Cannot fetch peers") return Machines{}, err @@ -203,13 +207,21 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) log.Trace(). - Str("func", "getShared"). + Caller(). Str("machine", m.Name). Msgf("Found total peers: %s", peers.String()) return peers, nil } +func (h *Headscale) ListMachines() ([]Machine, error) { + machines := []Machine{} + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil { + return nil, err + } + return machines, nil +} + // GetMachine finds a Machine by name and namespace and returns the Machine struct func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { machines, err := h.ListMachinesInNamespace(namespace) @@ -217,7 +229,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) return nil, err } - for _, m := range *machines { + for _, m := range machines { if m.Name == name { return &m, nil } @@ -326,7 +338,7 @@ func (h *Headscale) isOutdated(m *Machine) bool { lastChange := h.getLastStateChange(namespaces...) log.Trace(). - Str("func", "keepAlive"). + Caller(). Str("machine", m.Name). Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_state_change", lastChange). @@ -405,7 +417,7 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) if err != nil { log.Trace(). - Str("func", "toNode"). + Caller(). Str("ip", m.IPAddress). Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) return nil, err @@ -508,3 +520,212 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } return &n, nil } + +func (m *Machine) toProto() *v1.Machine { + machine := &v1.Machine{ + Id: m.ID, + MachineKey: m.MachineKey, + + NodeKey: m.NodeKey, + DiscoKey: m.DiscoKey, + IpAddress: m.IPAddress, + Name: m.Name, + Namespace: m.Namespace.toProto(), + + Registered: m.Registered, + + // TODO(kradalby): Implement register method enum converter + // RegisterMethod: , + + CreatedAt: timestamppb.New(m.CreatedAt), + } + + if m.AuthKey != nil { + machine.PreAuthKey = m.AuthKey.toProto() + } + + if m.LastSeen != nil { + machine.LastSeen = timestamppb.New(*m.LastSeen) + } + + if m.LastSuccessfulUpdate != nil { + machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) + } + + if m.Expiry != nil { + machine.Expiry = timestamppb.New(*m.Expiry) + } + + return machine +} + +// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey +func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { + ns, err := h.GetNamespace(namespace) + if err != nil { + return nil, err + } + mKey, err := wgkey.ParseHex(key) + if err != nil { + return nil, err + } + + m := Machine{} + if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, errors.New("Machine not found") + } + + log.Trace(). + Caller(). + Str("machine", m.Name). + Msg("Attempting to register machine") + + if m.isAlreadyRegistered() { + err := errors.New("Machine already registered") + log.Error(). + Caller(). + Err(err). + Str("machine", m.Name). + Msg("Attempting to register machine") + + return nil, err + } + + ip, err := h.getAvailableIP() + if err != nil { + log.Error(). + Caller(). + Err(err). + Str("machine", m.Name). + Msg("Could not find IP for the new machine") + return nil, err + } + + log.Trace(). + Caller(). + Str("machine", m.Name). + Str("ip", ip.String()). + Msg("Found IP for host") + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "cli" + h.db.Save(&m) + + log.Trace(). + Caller(). + Str("machine", m.Name). + Str("ip", ip.String()). + Msg("Machine registered with the database") + + return &m, nil +} + +func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { + hostInfo, err := m.GetHostInfo() + if err != nil { + return nil, err + } + return hostInfo.RoutableIPs, nil +} + +func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { + data, err := m.EnabledRoutes.MarshalJSON() + if err != nil { + return nil, err + } + + routesStr := []string{} + err = json.Unmarshal(data, &routesStr) + if err != nil { + return nil, err + } + + routes := make([]netaddr.IPPrefix, len(routesStr)) + for index, routeStr := range routesStr { + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return nil, err + } + routes[index] = route + } + + return routes, nil +} + +func (m *Machine) IsRoutesEnabled(routeStr string) bool { + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return false + } + + enabledRoutes, err := m.GetEnabledRoutes() + if err != nil { + return false + } + + for _, enabledRoute := range enabledRoutes { + if route == enabledRoute { + return true + } + } + return false +} + +// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the +// previous list of routes. +func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { + newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) + for index, routeStr := range routeStrs { + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return err + } + + newRoutes[index] = route + } + + availableRoutes, err := m.GetAdvertisedRoutes() + if err != nil { + return err + } + + for _, newRoute := range newRoutes { + if !containsIpPrefix(availableRoutes, newRoute) { + return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute) + } + } + + routes, err := json.Marshal(newRoutes) + if err != nil { + return err + } + + m.EnabledRoutes = datatypes.JSON(routes) + h.db.Save(&m) + + err = h.RequestMapUpdates(m.NamespaceID) + if err != nil { + return err + } + + return nil +} + +func (m *Machine) RoutesToProto() (*v1.Routes, error) { + availableRoutes, err := m.GetAdvertisedRoutes() + if err != nil { + return nil, err + } + + enabledRoutes, err := m.GetEnabledRoutes() + if err != nil { + return nil, err + } + + return &v1.Routes{ + AdvertisedRoutes: ipPrefixToString(availableRoutes), + EnabledRoutes: ipPrefixToString(enabledRoutes), + }, nil +} diff --git a/routes.go b/routes.go index 0ef01780..f07b709a 100644 --- a/routes.go +++ b/routes.go @@ -3,13 +3,12 @@ package headscale import ( "encoding/json" "fmt" - "strconv" - "github.com/pterm/pterm" "gorm.io/datatypes" "inet.af/netaddr" ) +// Deprecated: use machine function instead // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by // namespace and node name) func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { @@ -25,6 +24,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) ( return &hostInfo.RoutableIPs, nil } +// Deprecated: use machine function instead // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by // namespace and node name) func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) { @@ -56,6 +56,7 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n return routes, nil } +// Deprecated: use machine function instead // IsNodeRouteEnabled checks if a certain route has been enabled func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool { route, err := netaddr.ParseIPPrefix(routeStr) @@ -76,6 +77,7 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS return false } +// Deprecated: use EnableRoute in machine.go // EnableNodeRoute enables a subnet route advertised by a node (identified by // namespace and node name) func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { @@ -129,15 +131,3 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr return nil } - -// RoutesToPtables converts the list of routes to a nice table -func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData { - d := pterm.TableData{{"Route", "Enabled"}} - - for _, route := range availableRoutes { - enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String()) - - d = append(d, []string{route.String(), strconv.FormatBool(enabled)}) - } - return d -}