move to use tailscfg types over strings/custom types (#1612)

* rename database only fields

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use correct endpoint type over string list

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* remove HostInfo wrapper

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* wrap errors in database hooks

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-11-21 18:20:06 +01:00 committed by GitHub
parent ed4e19996b
commit b918aa03fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 147 additions and 154 deletions

View file

@ -2,6 +2,7 @@ package types
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
@ -27,27 +28,40 @@ var (
type Node struct {
ID uint64 `gorm:"primary_key"`
// MachineKeyValue is the string representation of MachineKey
// MachineKeyDatabaseField is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use MachineKey instead.
MachineKeyValue string `gorm:"column:machine_key;unique_index"`
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
MachineKey key.MachinePublic `gorm:"-"`
// NodeKeyValue is the string representation of NodeKey
// NodeKeyDatabaseField is the string representation of NodeKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use NodeKey instead.
NodeKeyValue string `gorm:"column:node_key"`
NodeKeyDatabaseField string `gorm:"column:node_key"`
NodeKey key.NodePublic `gorm:"-"`
// DiscoKeyValue is the string representation of DiscoKey
// DiscoKeyDatabaseField is the string representation of DiscoKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use DiscoKey instead.
DiscoKeyValue string `gorm:"column:disco_key"`
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
DiscoKey key.DiscoPublic `gorm:"-"`
MachineKey key.MachinePublic `gorm:"-"`
NodeKey key.NodePublic `gorm:"-"`
DiscoKey key.DiscoPublic `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
Endpoints []netip.AddrPort `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
HostinfoDatabaseField string `gorm:"column:hostinfo"`
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
IPAddresses NodeAddresses
@ -76,9 +90,6 @@ type Node struct {
LastSeen *time.Time
Expiry *time.Time
HostInfo HostInfo
Endpoints StringList
Routes []Route
CreatedAt time.Time
@ -195,31 +206,6 @@ func (node Node) IsExpired() bool {
return time.Now().UTC().After(*node.Expiry)
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) {
var ret []netip.AddrPort
for _, ep := range node.Endpoints {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return nil, err
}
ret = append(ret, addrPort)
}
return ret, nil
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) {
var strs StringList
for _, addrPort := range in {
strs = append(strs, addrPort.String())
}
node.Endpoints = strs
}
// IsOnline returns if the node is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
@ -277,9 +263,22 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
// correctly in the database.
// This currently means storing the keys as strings.
func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
n.MachineKeyValue = n.MachineKey.String()
n.NodeKeyValue = n.NodeKey.String()
n.DiscoKeyValue = n.DiscoKey.String()
n.MachineKeyDatabaseField = n.MachineKey.String()
n.NodeKeyDatabaseField = n.NodeKey.String()
n.DiscoKeyDatabaseField = n.DiscoKey.String()
var endpoints StringList
for _, addrPort := range n.Endpoints {
endpoints = append(endpoints, addrPort.String())
}
n.EndpointsDatabaseField = endpoints
hi, err := json.Marshal(n.Hostinfo)
if err != nil {
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
}
n.HostinfoDatabaseField = string(hi)
return
}
@ -291,23 +290,40 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
// the proper types.
func (n *Node) AfterFind(tx *gorm.DB) (err error) {
var machineKey key.MachinePublic
if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil {
return err
if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
}
n.MachineKey = machineKey
var nodeKey key.NodePublic
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil {
return err
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
}
n.NodeKey = nodeKey
var discoKey key.DiscoPublic
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil {
return err
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
}
n.DiscoKey = discoKey
var endpoints []netip.AddrPort
for _, ep := range n.EndpointsDatabaseField {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return fmt.Errorf("failed to parse endpoint from db: %w", err)
}
endpoints = append(endpoints, addrPort)
}
n.Endpoints = endpoints
var hi tailcfg.Hostinfo
if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil {
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
}
n.Hostinfo = &hi
return
}
@ -346,11 +362,6 @@ func (node *Node) Proto() *v1.Node {
return nodeProto
}
// GetHostInfo returns a Hostinfo struct for the node.
func (node *Node) GetHostInfo() tailcfg.Hostinfo {
return tailcfg.Hostinfo(node.HostInfo)
}
func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) {
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS