headscale/hscontrol/types/node.go

561 lines
14 KiB
Go
Raw Normal View History

2023-09-24 20:42:05 +09:00
package types
import (
"database/sql"
"encoding/json"
2023-09-24 20:42:05 +09:00
"errors"
"fmt"
"net/netip"
"strconv"
2023-09-24 20:42:05 +09:00
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
2023-09-24 20:42:05 +09:00
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
2023-09-24 20:42:05 +09:00
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// Sentinel errors produced by the Node helpers in this file. Callers should
// match them with errors.Is, since they are usually returned wrapped.
var (
	// ErrNodeAddressesInvalid signals that a node's addresses could not be parsed.
	ErrNodeAddressesInvalid = errors.New("failed to parse node addresses")
	// ErrHostnameTooLong signals that a generated hostname is longer than
	// the maximum allowed hostname length.
	ErrHostnameTooLong = errors.New("hostname too long, cannot exceed 255 ASCII chars")
	// ErrNodeHasNoGivenName signals that a node has an empty GivenName.
	ErrNodeHasNoGivenName = errors.New("node has no given name")
	// ErrNodeUserHasNoName signals that the user owning a node has an empty name.
	ErrNodeUserHasNoName = errors.New("node user has no name")
)
// NodeID is the numeric identifier of a Node. It is the type of Node.ID
// and can be converted to the tailcfg identifier types through its methods.
type NodeID uint64
// type NodeConnectedMap *xsync.MapOf[NodeID, bool]
// StableID returns the ID as a tailcfg.StableNodeID, which is the
// base-10 string form of the numeric identifier.
func (id NodeID) StableID() tailcfg.StableNodeID {
	decimal := id.String()

	return tailcfg.StableNodeID(decimal)
}
// NodeID converts the ID to the tailcfg.NodeID type.
func (id NodeID) NodeID() tailcfg.NodeID {
	converted := tailcfg.NodeID(id)

	return converted
}
// Uint64 returns the ID as a plain uint64.
func (id NodeID) Uint64() uint64 {
	return uint64(id)
}
// String returns the ID formatted as a base-10 string.
func (id NodeID) String() string {
	raw := uint64(id)

	return strconv.FormatUint(raw, util.Base10)
}
2023-09-24 20:42:05 +09:00
// Node is a Headscale client.
type Node struct {
ID NodeID `gorm:"primary_key"`
// MachineKeyDatabaseField is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use MachineKey instead.
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
MachineKey key.MachinePublic `gorm:"-"`
// NodeKeyDatabaseField is the string representation of NodeKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use NodeKey instead.
NodeKeyDatabaseField string `gorm:"column:node_key"`
NodeKey key.NodePublic `gorm:"-"`
// DiscoKeyDatabaseField is the string representation of DiscoKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use DiscoKey instead.
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
DiscoKey key.DiscoPublic `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
Endpoints []netip.AddrPort `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
HostinfoDatabaseField string `gorm:"column:host_info"`
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
// IPv4DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V4 instead.
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
IPv4 *netip.Addr `gorm:"-"`
// IPv6DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V6 instead.
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
IPv6 *netip.Addr `gorm:"-"`
2023-09-24 20:42:05 +09:00
// Hostname represents the name given by the Tailscale
// client during registration
Hostname string
// Givenname represents either:
// a DNS normalized version of Hostname
// a valid name set by the User
//
// GivenName is the name used in all DNS related
// parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"`
UserID uint
User User `gorm:"constraint:OnDelete:CASCADE;"`
2023-09-24 20:42:05 +09:00
RegisterMethod string
ForcedTags StringList
// TODO(kradalby): This seems like irrelevant information?
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"`
2023-09-24 20:42:05 +09:00
LastSeen *time.Time
Expiry *time.Time
Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
2023-09-24 20:42:05 +09:00
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
IsOnline *bool `gorm:"-"`
2023-09-24 20:42:05 +09:00
}
// Nodes is a list of pointers to Node values.
type Nodes []*Node
// IsExpired returns whether the node registration has expired.
func (node Node) IsExpired() bool {
// If Expiry is not set, the client has not indicated that
// it wants an expiry time, it is therefor considered
// to mean "not expired"
if node.Expiry == nil || node.Expiry.IsZero() {
return false
}
2023-09-24 20:42:05 +09:00
return time.Since(*node.Expiry) > 0
}
2023-09-24 20:42:05 +09:00
// IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (node *Node) IsEphemeral() bool {
return node.AuthKey != nil && node.AuthKey.Ephemeral
2023-09-24 20:42:05 +09:00
}
func (node *Node) IPs() []netip.Addr {
var ret []netip.Addr
if node.IPv4 != nil {
ret = append(ret, *node.IPv4)
}
if node.IPv6 != nil {
ret = append(ret, *node.IPv6)
2023-09-24 20:42:05 +09:00
}
return ret
2023-09-24 20:42:05 +09:00
}
func (node *Node) Prefixes() []netip.Prefix {
2023-09-24 20:42:05 +09:00
addrs := []netip.Prefix{}
for _, nodeAddress := range node.IPs() {
2023-09-24 20:42:05 +09:00
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
addrs = append(addrs, ip)
}
return addrs
}
// IPsAsString returns the node's addresses formatted as strings, IPv4
// first and then IPv6. The result is nil when the node has no address.
func (node *Node) IPsAsString() []string {
	var formatted []string

	for _, addr := range node.IPs() {
		formatted = append(formatted, addr.String())
	}

	return formatted
}
func (node *Node) InIPSet(set *netipx.IPSet) bool {
for _, nodeAddr := range node.IPs() {
2023-09-24 20:42:05 +09:00
if set.Contains(nodeAddr) {
return true
}
}
return false
}
// AppendToIPSet adds the individual ips in NodeAddresses to a
// given netipx.IPSetBuilder.
func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
for _, ip := range node.IPs() {
2023-09-24 20:42:05 +09:00
build.Add(ip)
}
}
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
src := node.IPs()
allowedIPs := node2.IPs()
for _, route := range node2.Routes {
if route.Enabled {
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
}
}
2023-09-24 20:42:05 +09:00
for _, rule := range filter {
// TODO(kradalby): Cache or pregen this
matcher := matcher.MatchFromFilterRule(rule)
if !matcher.SrcsContainsIPs(src) {
2023-09-24 20:42:05 +09:00
continue
}
if matcher.DestsContainsIP(allowedIPs) {
2023-09-24 20:42:05 +09:00
return true
}
}
return false
}
func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
var found Nodes
2023-09-24 20:42:05 +09:00
for _, node := range nodes {
if node.IPv4 != nil && ip == *node.IPv4 {
found = append(found, node)
continue
}
if node.IPv6 != nil && ip == *node.IPv6 {
found = append(found, node)
2023-09-24 20:42:05 +09:00
}
}
return found
}
// BeforeSave is a hook that copies the typed fields that cannot be
// marshalled directly into database values into their "...DatabaseField"
// companions: the keys and endpoints as strings, Hostinfo as JSON, and the
// IPv4/IPv6 addresses as nullable strings (NULL when the address is unset).
//
// It returns an error only when Hostinfo cannot be marshalled to JSON.
func (node *Node) BeforeSave(tx *gorm.DB) error {
	node.MachineKeyDatabaseField = node.MachineKey.String()
	node.NodeKeyDatabaseField = node.NodeKey.String()
	node.DiscoKeyDatabaseField = node.DiscoKey.String()

	// Stays nil when the node has no endpoints.
	var endpointStrings StringList
	for i := range node.Endpoints {
		endpointStrings = append(endpointStrings, node.Endpoints[i].String())
	}
	node.EndpointsDatabaseField = endpointStrings

	hostinfoJSON, err := json.Marshal(node.Hostinfo)
	if err != nil {
		return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
	}
	node.HostinfoDatabaseField = string(hostinfoJSON)

	// The zero sql.NullString is {String: "", Valid: false}, i.e. NULL.
	node.IPv4DatabaseField = sql.NullString{}
	if node.IPv4 != nil {
		node.IPv4DatabaseField = sql.NullString{String: node.IPv4.String(), Valid: true}
	}

	node.IPv6DatabaseField = sql.NullString{}
	if node.IPv6 != nil {
		node.IPv6DatabaseField = sql.NullString{String: node.IPv6.String(), Valid: true}
	}

	return nil
}
// AfterFind is a hook that populates the typed Node fields from their
// "...DatabaseField" companions after a row has been loaded: the machine,
// node and disco keys, the endpoints, the Hostinfo JSON and the IPv4/IPv6
// addresses.
//
// It stops at the first value that fails to parse and returns a wrapped
// error; fields processed before the failure have already been updated.
func (node *Node) AfterFind(tx *gorm.DB) error {
	var mkey key.MachinePublic
	err := mkey.UnmarshalText([]byte(node.MachineKeyDatabaseField))
	if err != nil {
		return fmt.Errorf("unmarshalling machine key from db: %w", err)
	}
	node.MachineKey = mkey

	var nkey key.NodePublic
	err = nkey.UnmarshalText([]byte(node.NodeKeyDatabaseField))
	if err != nil {
		return fmt.Errorf("unmarshalling node key from db: %w", err)
	}
	node.NodeKey = nkey

	// The disco key is only parsed when present: a node that has not
	// sent one to headscale has an empty column, which is skipped here.
	if rawDisco := node.DiscoKeyDatabaseField; rawDisco != "" {
		var dkey key.DiscoPublic
		err = dkey.UnmarshalText([]byte(rawDisco))
		if err != nil {
			return fmt.Errorf("unmarshalling disco key from db: %w", err)
		}
		node.DiscoKey = dkey
	}

	parsedEndpoints := make([]netip.AddrPort, 0, len(node.EndpointsDatabaseField))
	for _, rawEndpoint := range node.EndpointsDatabaseField {
		endpoint, parseErr := netip.ParseAddrPort(rawEndpoint)
		if parseErr != nil {
			return fmt.Errorf("parsing endpoint from db: %w", parseErr)
		}
		parsedEndpoints = append(parsedEndpoints, endpoint)
	}
	node.Endpoints = parsedEndpoints

	hostinfo := new(tailcfg.Hostinfo)
	err = json.Unmarshal([]byte(node.HostinfoDatabaseField), hostinfo)
	if err != nil {
		return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
	}
	node.Hostinfo = hostinfo

	// NULL address columns leave the corresponding pointer untouched.
	if v4 := node.IPv4DatabaseField; v4.Valid {
		addr, parseErr := netip.ParseAddr(v4.String)
		if parseErr != nil {
			return fmt.Errorf("parsing IPv4 from database: %w", parseErr)
		}
		node.IPv4 = &addr
	}

	if v6 := node.IPv6DatabaseField; v6.Valid {
		addr, parseErr := netip.ParseAddr(v6.String)
		if parseErr != nil {
			return fmt.Errorf("parsing IPv6 from database: %w", parseErr)
		}
		node.IPv6 = &addr
	}

	return nil
}
2023-09-24 20:42:05 +09:00
func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{
Id: uint64(node.ID),
MachineKey: node.MachineKey.String(),
2023-09-24 20:42:05 +09:00
NodeKey: node.NodeKey.String(),
DiscoKey: node.DiscoKey.String(),
// TODO(kradalby): replace list with v4, v6 field?
IpAddresses: node.IPsAsString(),
2023-09-24 20:42:05 +09:00
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
RegisterMethod: node.RegisterMethodToV1Enum(),
2023-09-24 20:42:05 +09:00
CreatedAt: timestamppb.New(node.CreatedAt),
}
if node.AuthKey != nil {
nodeProto.PreAuthKey = node.AuthKey.Proto()
}
if node.LastSeen != nil {
nodeProto.LastSeen = timestamppb.New(*node.LastSeen)
}
if node.Expiry != nil {
nodeProto.Expiry = timestamppb.New(*node.Expiry)
}
return nodeProto
}
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
hostname := node.GivenName
if baseDomain != "" {
2023-09-24 20:42:05 +09:00
hostname = fmt.Sprintf(
"%s.%s",
2023-09-24 20:42:05 +09:00
node.GivenName,
baseDomain,
)
}
if cfg.DNSUserNameInMagicDNS {
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
node.User.Name,
baseDomain,
)
}
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"failed to create valid FQDN (%s): %w",
hostname,
ErrHostnameTooLong,
)
2023-09-24 20:42:05 +09:00
}
return hostname, nil
}
// func (node *Node) String() string {
// return node.Hostname
// }
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to update the node and
// inform peers about smaller changes to the node.
//
// The returned change always carries the node's ID, the request's
// Endpoints and a LastSeen of "now". Key, DiscoKey and DERPRegion are only
// set when the request differs from what is stored on the node.
//
// When a field is added to this function, remember to also add it to:
// - node.ApplyPeerChange
// - logTracePeerChange in poll.go.
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
	ret := tailcfg.PeerChange{
		NodeID: tailcfg.NodeID(node.ID),
	}

	if node.NodeKey.String() != req.NodeKey.String() {
		ret.Key = &req.NodeKey
	}

	if node.DiscoKey.String() != req.DiscoKey.String() {
		ret.DiscoKey = &req.DiscoKey
	}

	// Report the requested PreferredDERP when nothing is stored yet
	// (no Hostinfo or no NetInfo) or when the stored value differs.
	if req.Hostinfo != nil && req.Hostinfo.NetInfo != nil {
		requestedDERP := req.Hostinfo.NetInfo.PreferredDERP

		if node.Hostinfo == nil ||
			node.Hostinfo.NetInfo == nil ||
			node.Hostinfo.NetInfo.PreferredDERP != requestedDERP {
			ret.DERPRegion = requestedDERP
		}
	}

	// TODO(kradalby): Find a good way to compare updates
	ret.Endpoints = req.Endpoints

	now := time.Now()
	ret.LastSeen = &now

	return ret
}
// RegisterMethodToV1Enum maps the node's RegisterMethod string ("authkey",
// "oidc" or "cli") to the matching v1.RegisterMethod value. Any other
// string maps to REGISTER_METHOD_UNSPECIFIED.
func (node *Node) RegisterMethodToV1Enum() v1.RegisterMethod {
	method := v1.RegisterMethod_REGISTER_METHOD_UNSPECIFIED

	switch node.RegisterMethod {
	case "authkey":
		method = v1.RegisterMethod_REGISTER_METHOD_AUTH_KEY
	case "oidc":
		method = v1.RegisterMethod_REGISTER_METHOD_OIDC
	case "cli":
		method = v1.RegisterMethod_REGISTER_METHOD_CLI
	}

	return method
}
// ApplyPeerChange takes a PeerChange struct and updates the node.
func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) {
if change.Key != nil {
node.NodeKey = *change.Key
}
if change.DiscoKey != nil {
node.DiscoKey = *change.DiscoKey
}
if change.Online != nil {
node.IsOnline = change.Online
}
if change.Endpoints != nil {
node.Endpoints = change.Endpoints
}
// This might technically not be useful as we replace
// the whole hostinfo blob when it has changed.
if change.DERPRegion != 0 {
if node.Hostinfo == nil {
node.Hostinfo = &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
},
}
} else if node.Hostinfo.NetInfo == nil {
node.Hostinfo.NetInfo = &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
}
} else {
node.Hostinfo.NetInfo.PreferredDERP = change.DERPRegion
}
}
node.LastSeen = change.LastSeen
2023-09-24 20:42:05 +09:00
}
// String renders the list as "[ host1, host2 ](N)", where the names are the
// nodes' Hostnames and N is the number of nodes.
func (nodes Nodes) String() string {
	hostnames := make([]string, 0, len(nodes))

	for _, n := range nodes {
		hostnames = append(hostnames, n.Hostname)
	}

	joined := strings.Join(hostnames, ", ")

	return fmt.Sprintf("[ %s ](%d)", joined, len(hostnames))
}
func (nodes Nodes) IDMap() map[NodeID]*Node {
ret := map[NodeID]*Node{}
2023-09-24 20:42:05 +09:00
for _, node := range nodes {
ret[node.ID] = node
}
return ret
}