Consolidate machine related lookups

This commit moves the routes lookup functions to be subcommands of
Machine, making them a lot simpler and more specific/composable.

It also moves the register command from cli.go into machine, so we can
clear out the extra file.

Finally a toProto function has been added to convert between the machine
database model and the proto/rpc model.
This commit is contained in:
Kristoffer Dalby 2021-11-04 22:11:38 +00:00
parent 67adea5cab
commit 787814ea89
3 changed files with 239 additions and 71 deletions

43
cli.go
View file

@ -1,43 +0,0 @@
package headscale
import (
"errors"
"gorm.io/gorm"
"tailscale.com/types/wgkey"
)
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
ns, err := h.GetNamespace(namespace)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
m := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found")
}
h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered
if m.isAlreadyRegistered() {
return nil, errors.New("Machine already registered")
}
ip, err := h.getAvailableIP()
if err != nil {
return nil, err
}
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "cli"
h.db.Save(&m)
return &m, nil
}

View file

@ -2,6 +2,7 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
@ -10,8 +11,11 @@ import (
"github.com/fatih/set"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
@ -91,7 +95,7 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace().
Str("func", "getDirectPeers").
Caller().
Str("machine", m.Name).
Msg("Finding direct peers")
@ -105,7 +109,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID })
log.Trace().
Str("func", "getDirectmachines").
Caller().
Str("machine", m.Name).
Msgf("Found direct machines: %s", machines.String())
return machines, nil
@ -114,7 +118,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace().
Str("func", "getShared").
Caller().
Str("machine", m.Name).
Msg("Finding shared peers")
@ -132,7 +136,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Str("func", "getShared").
Caller().
Str("machine", m.Name).
Msgf("Found shared peers: %s", peers.String())
return peers, nil
@ -141,7 +145,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
// getSharedTo fetches the machines of the namespaces this machine is shared in
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Str("func", "getSharedTo").
Caller().
Str("machine", m.Name).
Msg("Finding peers in namespaces this machine is shared with")
@ -157,13 +161,13 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
if err != nil {
return Machines{}, err
}
peers = append(peers, *namespaceMachines...)
peers = append(peers, namespaceMachines...)
}
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Str("func", "getSharedTo").
Caller().
Str("machine", m.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
@ -173,7 +177,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m)
if err != nil {
log.Error().
Str("func", "getPeers").
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
@ -182,7 +186,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
shared, err := h.getShared(m)
if err != nil {
log.Error().
Str("func", "getShared").
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
@ -191,7 +195,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
sharedTo, err := h.getSharedTo(m)
if err != nil {
log.Error().
Str("func", "sharedTo").
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
@ -203,13 +207,21 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Str("func", "getShared").
Caller().
Str("machine", m.Name).
Msgf("Found total peers: %s", peers.String())
return peers, nil
}
func (h *Headscale) ListMachines() ([]Machine, error) {
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
// GetMachine finds a Machine by name and namespace and returns the Machine struct
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
machines, err := h.ListMachinesInNamespace(namespace)
@ -217,7 +229,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return nil, err
}
for _, m := range *machines {
for _, m := range machines {
if m.Name == name {
return &m, nil
}
@ -326,7 +338,7 @@ func (h *Headscale) isOutdated(m *Machine) bool {
lastChange := h.getLastStateChange(namespaces...)
log.Trace().
Str("func", "keepAlive").
Caller().
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
@ -405,7 +417,7 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
if err != nil {
log.Trace().
Str("func", "toNode").
Caller().
Str("ip", m.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
return nil, err
@ -508,3 +520,212 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
return &n, nil
}
func (m *Machine) toProto() *v1.Machine {
machine := &v1.Machine{
Id: m.ID,
MachineKey: m.MachineKey,
NodeKey: m.NodeKey,
DiscoKey: m.DiscoKey,
IpAddress: m.IPAddress,
Name: m.Name,
Namespace: m.Namespace.toProto(),
Registered: m.Registered,
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
CreatedAt: timestamppb.New(m.CreatedAt),
}
if m.AuthKey != nil {
machine.PreAuthKey = m.AuthKey.toProto()
}
if m.LastSeen != nil {
machine.LastSeen = timestamppb.New(*m.LastSeen)
}
if m.LastSuccessfulUpdate != nil {
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
}
if m.Expiry != nil {
machine.Expiry = timestamppb.New(*m.Expiry)
}
return machine
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
ns, err := h.GetNamespace(namespace)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
m := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found")
}
log.Trace().
Caller().
Str("machine", m.Name).
Msg("Attempting to register machine")
if m.isAlreadyRegistered() {
err := errors.New("Machine already registered")
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Msg("Attempting to register machine")
return nil, err
}
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Msg("Could not find IP for the new machine")
return nil, err
}
log.Trace().
Caller().
Str("machine", m.Name).
Str("ip", ip.String()).
Msg("Found IP for host")
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "cli"
h.db.Save(&m)
log.Trace().
Caller().
Str("machine", m.Name).
Str("ip", ip.String()).
Msg("Machine registered with the database")
return &m, nil
}
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := m.GetHostInfo()
if err != nil {
return nil, err
}
return hostInfo.RoutableIPs, nil
}
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
routes[index] = route
}
return routes, nil
}
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := m.GetEnabledRoutes()
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return err
}
newRoutes[index] = route
}
availableRoutes, err := m.GetAdvertisedRoutes()
if err != nil {
return err
}
for _, newRoute := range newRoutes {
if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
}
}
routes, err := json.Marshal(newRoutes)
if err != nil {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return err
}
return nil
}
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := m.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
enabledRoutes, err := m.GetEnabledRoutes()
if err != nil {
return nil, err
}
return &v1.Routes{
AdvertisedRoutes: ipPrefixToString(availableRoutes),
EnabledRoutes: ipPrefixToString(enabledRoutes),
}, nil
}

View file

@ -3,13 +3,12 @@ package headscale
import (
"encoding/json"
"fmt"
"strconv"
"github.com/pterm/pterm"
"gorm.io/datatypes"
"inet.af/netaddr"
)
// Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
@ -25,6 +24,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (
return &hostInfo.RoutableIPs, nil
}
// Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
@ -56,6 +56,7 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
return routes, nil
}
// Deprecated: use machine function instead
// IsNodeRouteEnabled checks if a certain route has been enabled
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
@ -76,6 +77,7 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
return false
}
// Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
@ -129,15 +131,3 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
return nil
}
// RoutesToPtables converts the list of routes to a nice table
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
for _, route := range availableRoutes {
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
}
return d
}