Redo route code (#2422)

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-26 07:22:55 -08:00 committed by GitHub
parent 16868190c8
commit 7891378f57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
53 changed files with 2977 additions and 6251 deletions

View file

@ -12,12 +12,10 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -50,7 +48,6 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("id <> ?",
nodeID).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
@ -73,7 +70,6 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Find(&nodes).Error; err != nil {
return nil, err
}
@ -127,7 +123,6 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Find(&types.Node{ID: id}).First(&mach); result.Error != nil {
return nil, result.Error
}
@ -151,7 +146,6 @@ func GetNodeByMachineKey(
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
return nil, result.Error
}
@ -175,7 +169,6 @@ func GetNodeByNodeKey(
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
return nil, result.Error
}
@ -201,7 +194,7 @@ func SetTags(
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
return fmt.Errorf("removing tags: %w", err)
}
return nil
@ -220,7 +213,34 @@ func SetTags(
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
return fmt.Errorf("updating tags: %w", err)
}
return nil
}
// SetTags takes a Node struct pointer and update the forced tags.
func SetApprovedRoutes(
tx *gorm.DB,
nodeID types.NodeID,
routes []netip.Prefix,
) error {
if len(routes) == 0 {
// if no routes are provided, we remove all
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
return fmt.Errorf("removing approved routes: %w", err)
}
return nil
}
b, err := json.Marshal(routes)
if err != nil {
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
return fmt.Errorf("updating approved routes: %w", err)
}
return nil
@ -267,9 +287,9 @@ func NodeSetExpiry(tx *gorm.DB,
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteNode(tx, node, isLikelyConnected)
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DeleteNode(tx, node)
})
}
@ -277,19 +297,13 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.Ma
// Caller is responsible for notifying all of change.
func DeleteNode(tx *gorm.DB,
node *types.Node,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) {
changed, err := deleteNodeRoutes(tx, node, isLikelyConnected)
if err != nil {
return changed, err
}
) error {
// Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
return changed, err
return err
}
return changed, nil
return nil
}
// DeleteEphemeralNode deletes a Node from the database, note that this method
@ -306,12 +320,6 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
})
}
// SetLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node.
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
@ -458,10 +466,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
}
if _, err := SaveNodeRoutes(tx, &node); err != nil {
return nil, fmt.Errorf("failed to save node routes: %w", err)
}
log.Trace().
Caller().
Str("node", node.Hostname).
@ -504,141 +508,6 @@ func NodeSave(tx *gorm.DB, node *types.Node) error {
return tx.Save(node).Error
}
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetAdvertisedRoutes(rx, node)
})
}
// GetAdvertisedRoutes returns the routes that are be advertised by the given node.
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("getting advertised routes for node(%d): %w", node.ID, err)
}
var prefixes []netip.Prefix
for _, route := range routes {
prefixes = append(prefixes, netip.Prefix(route.Prefix))
}
return prefixes, nil
}
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetEnabledRoutes(rx, node)
})
}
// GetEnabledRoutes returns the routes that are enabled for the node.
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("getting enabled routes for node(%d): %w", node.ID, err)
}
var prefixes []netip.Prefix
for _, route := range routes {
prefixes = append(prefixes, netip.Prefix(route.Prefix))
}
return prefixes, nil
}
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := GetEnabledRoutes(tx, node)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
func (hsdb *HSDatabase) enableRoutes(
node *types.Node,
newRoutes ...netip.Prefix,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return enableRoutes(tx, node, newRoutes...)
})
}
// enableRoutes enables new routes based on a list of new routes.
func enableRoutes(tx *gorm.DB,
node *types.Node, newRoutes ...netip.Prefix,
) (*types.StateUpdate, error) {
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
if err != nil {
return nil, err
}
for _, newRoute := range newRoutes {
if !slices.Contains(advertisedRoutes, newRoute) {
return nil, fmt.Errorf(
"route (%s) is not available on node %s: %w",
node.Hostname,
newRoute, ErrNodeRouteIsNotAvailable,
)
}
}
// Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes {
route := types.Route{}
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
First(&route).Error
if err == nil {
route.Enabled = true
// Mark already as primary if there is only this node offering this subnet
// (and is not an exit route)
if !route.IsExitRoute() {
route.IsPrimary = isUniquePrefix(tx, route)
}
err = tx.Save(&route).Error
if err != nil {
return nil, fmt.Errorf("failed to enable route: %w", err)
}
} else {
return nil, fmt.Errorf("failed to find route: %w", err)
}
}
// Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil, fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
return ptr.To(types.UpdatePeerChanged(node.ID)), nil
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
suppliedName = util.ConvertWithFQDNRules(suppliedName)
if len(suppliedName) > util.LabelHostnameLength {