package types import ( "database/sql/driver" "errors" "fmt" "net/netip" "sort" "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/tailcfg" "tailscale.com/types/key" ) var ( ErrNodeAddressesInvalid = errors.New("failed to parse node addresses") ErrHostnameTooLong = errors.New("hostname too long") ) // Node is a Headscale client. type Node struct { ID uint64 `gorm:"primary_key"` MachineKey string `gorm:"type:varchar(64);unique_index"` NodeKey string DiscoKey string IPAddresses NodeAddresses // Hostname represents the name given by the Tailscale // client during registration Hostname string // Givenname represents either: // a DNS normalized version of Hostname // a valid name set by the User // // GivenName is the name used in all DNS related // parts of headscale. GivenName string `gorm:"type:varchar(63);unique_index"` UserID uint User User `gorm:"foreignKey:UserID"` RegisterMethod string ForcedTags StringList // TODO(kradalby): This seems like irrelevant information? AuthKeyID uint AuthKey *PreAuthKey LastSeen *time.Time Expiry *time.Time HostInfo HostInfo Endpoints StringList Routes []Route CreatedAt time.Time UpdatedAt time.Time DeletedAt *time.Time } type ( Nodes []*Node ) type NodeAddresses []netip.Addr func (na NodeAddresses) Sort() { sort.Slice(na, func(index1, index2 int) bool { if na[index1].Is4() && na[index2].Is6() { return true } if na[index1].Is6() && na[index2].Is4() { return false } return na[index1].Compare(na[index2]) < 0 }) } func (na NodeAddresses) StringSlice() []string { na.Sort() strSlice := make([]string, 0, len(na)) for _, addr := range na { strSlice = append(strSlice, addr.String()) } return strSlice } func (na NodeAddresses) Prefixes() []netip.Prefix { addrs := []netip.Prefix{} for _, nodeAddress := range na { ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) addrs = append(addrs, ip) } return addrs } func (na NodeAddresses) InIPSet(set *netipx.IPSet) bool { for _, nodeAddr := range na { if set.Contains(nodeAddr) { return true } } return false } // AppendToIPSet adds the individual ips in NodeAddresses to a // given netipx.IPSetBuilder. func (na NodeAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { for _, ip := range na { build.Add(ip) } } func (na *NodeAddresses) Scan(destination interface{}) error { switch value := destination.(type) { case string: addresses := strings.Split(value, ",") *na = (*na)[:0] for _, addr := range addresses { if len(addr) < 1 { continue } parsed, err := netip.ParseAddr(addr) if err != nil { return err } *na = append(*na, parsed) } return nil default: return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) } } // Value return json value, implement driver.Valuer interface. func (na NodeAddresses) Value() (driver.Value, error) { addresses := strings.Join(na.StringSlice(), ",") return addresses, nil } // IsExpired returns whether the node registration has expired. func (node Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that // it wants an expiry time, it is therefor considered // to mean "not expired" if node.Expiry == nil || node.Expiry.IsZero() { return false } return time.Now().UTC().After(*node.Expiry) } // IsOnline returns if the node is connected to Headscale. // This is really a naive implementation, as we don't really see // if there is a working connection between the client and the server. func (node *Node) IsOnline() bool { if node.LastSeen == nil { return false } if node.IsExpired() { return false } return node.LastSeen.After(time.Now().Add(-KeepAliveInterval)) } // IsEphemeral returns if the node is registered as an Ephemeral node. // https://tailscale.com/kb/1111/ephemeral-nodes/ func (node *Node) IsEphemeral() bool { return node.AuthKey != nil && node.AuthKey.Ephemeral } func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { for _, rule := range filter { // TODO(kradalby): Cache or pregen this matcher := matcher.MatchFromFilterRule(rule) if !matcher.SrcsContainsIPs([]netip.Addr(node.IPAddresses)) { continue } if matcher.DestsContainsIP([]netip.Addr(node2.IPAddresses)) { return true } } return false } func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { found := make(Nodes, 0) for _, node := range nodes { for _, mIP := range node.IPAddresses { if ip == mIP { found = append(found, node) } } } return found } func (node *Node) Proto() *v1.Node { nodeProto := &v1.Node{ Id: node.ID, MachineKey: node.MachineKey, NodeKey: node.NodeKey, DiscoKey: node.DiscoKey, IpAddresses: node.IPAddresses.StringSlice(), Name: node.Hostname, GivenName: node.GivenName, User: node.User.Proto(), ForcedTags: node.ForcedTags, Online: node.IsOnline(), // TODO(kradalby): Implement register method enum converter // RegisterMethod: , CreatedAt: timestamppb.New(node.CreatedAt), } if node.AuthKey != nil { nodeProto.PreAuthKey = node.AuthKey.Proto() } if node.LastSeen != nil { nodeProto.LastSeen = timestamppb.New(*node.LastSeen) } if node.Expiry != nil { nodeProto.Expiry = timestamppb.New(*node.Expiry) } return nodeProto } // GetHostInfo returns a Hostinfo struct for the node. func (node *Node) GetHostInfo() tailcfg.Hostinfo { return tailcfg.Hostinfo(node.HostInfo) } func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS hostname = fmt.Sprintf( "%s.%s.%s", node.GivenName, node.User.Name, baseDomain, ) if len(hostname) > MaxHostnameLength { return "", fmt.Errorf( "hostname %q is too long it cannot except 255 ASCII chars: %w", hostname, ErrHostnameTooLong, ) } } else { hostname = node.GivenName } return hostname, nil } func (node *Node) MachinePublicKey() (key.MachinePublic, error) { var machineKey key.MachinePublic if node.MachineKey != "" { err := machineKey.UnmarshalText( []byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)), ) if err != nil { return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err) } } return machineKey, nil } func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) { var discoKey key.DiscoPublic if node.DiscoKey != "" { err := discoKey.UnmarshalText( []byte(util.DiscoPublicKeyEnsurePrefix(node.DiscoKey)), ) if err != nil { return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err) } } else { discoKey = key.DiscoPublic{} } return discoKey, nil } func (node *Node) NodePublicKey() (key.NodePublic, error) { var nodeKey key.NodePublic err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey))) if err != nil { return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err) } return nodeKey, nil } func (node Node) String() string { return node.Hostname } func (nodes Nodes) String() string { temp := make([]string, len(nodes)) for index, node := range nodes { temp[index] = node.Hostname } return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } func (nodes Nodes) IDMap() map[uint64]*Node { ret := map[uint64]*Node{} for _, node := range nodes { ret[node.ID] = node } return ret }