Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
cbf57e27a7
commit
83769ba715
32 changed files with 1496 additions and 1128 deletions
|
@ -7,23 +7,15 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoutes()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
|
||||
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Find(&routes).Error
|
||||
|
@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
|
||||
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("advertised = ? AND enabled = ?", true, true).
|
||||
|
@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
|
||||
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("prefix = ?", types.IPPrefix(pref)).
|
||||
|
@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getNodeAdvertisedRoutes(node)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
|
||||
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ? AND advertised = true", node.ID).
|
||||
|
@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getNodeRoutes(node)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodeRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
|
||||
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ?", node.ID).
|
||||
|
@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
|
||||
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
First(&route, id).Error
|
||||
|
@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
|
|||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.enableRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) enableRoute(id uint64) error {
|
||||
route, err := hsdb.getRoute(id)
|
||||
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if route.IsExitRoute() {
|
||||
return hsdb.enableRoutes(
|
||||
return enableRoutes(
|
||||
tx,
|
||||
&route.Node,
|
||||
types.ExitRouteV4.String(),
|
||||
types.ExitRouteV6.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String())
|
||||
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
func DisableRoute(tx *gorm.DB,
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
|
@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
|||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
var update *types.StateUpdate
|
||||
if !route.IsExitRoute() {
|
||||
err = hsdb.failoverRouteWithNotify(route)
|
||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
route.Enabled = false
|
||||
route.IsPrimary = false
|
||||
err = hsdb.db.Save(route).Error
|
||||
err = tx.Save(route).Error
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if routes[i].IsExitRoute() {
|
||||
routes[i].Enabled = false
|
||||
routes[i].IsPrimary = false
|
||||
err = hsdb.db.Save(&routes[i]).Error
|
||||
err = tx.Save(&routes[i]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if routes == nil {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node.Routes = routes
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{&node},
|
||||
Message: "called from db.DisableRoute",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if update == nil {
|
||||
update = &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{
|
||||
&node,
|
||||
},
|
||||
Message: "called from db.DisableRoute",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
func (hsdb *HSDatabase) DeleteRoute(
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return DeleteRoute(tx, id, isConnected)
|
||||
})
|
||||
}
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
func DeleteRoute(
|
||||
tx *gorm.DB,
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
|
@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
var update *types.StateUpdate
|
||||
if !route.IsExitRoute() {
|
||||
err := hsdb.failoverRouteWithNotify(route)
|
||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||
return err
|
||||
if err := tx.Unscoped().Delete(&route).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
routes, err := hsdb.getNodeRoutes(&node)
|
||||
routes, err := GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routesToDelete := types.Routes{}
|
||||
|
@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return err
|
||||
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if routes == nil {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node.Routes = routes
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{&node},
|
||||
Message: "called from db.DeleteRoute",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
if update == nil {
|
||||
update = &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{
|
||||
&node,
|
||||
},
|
||||
Message: "called from db.DeleteRoute",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
|
||||
routes, err := hsdb.getNodeRoutes(node)
|
||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
|
||||
routes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(kradalby): This is a bit too aggressive, we could probably
|
||||
// figure out which routes needs to be failed over rather than all.
|
||||
hsdb.failoverRouteWithNotify(&routes[i])
|
||||
failoverRouteReturnUpdate(tx, isConnected, &routes[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another node providing the same route already.
|
||||
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
||||
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
|
||||
var count int64
|
||||
hsdb.db.
|
||||
Model(&types.Route{}).
|
||||
tx.Model(&types.Route{}).
|
||||
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.NodeID,
|
||||
|
@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
|||
return count == 0
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
|
||||
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||
First(&route).Error
|
||||
|
@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
|
|||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodePrimaryRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
|
||||
Find(&routes).Error
|
||||
|
@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
|
||||
return SaveNodeRoutes(tx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNodeRoutes takes a node and updates the database with
|
||||
// the new routes.
|
||||
// It returns a bool whether an update should be sent as the
|
||||
// saved route impacts nodes.
|
||||
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.saveNodeRoutes(node)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
||||
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||
sendUpdate := false
|
||||
|
||||
currentRoutes := types.Routes{}
|
||||
err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error
|
||||
err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
||||
if !route.Advertised {
|
||||
currentRoutes[pos].Advertised = true
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
} else if route.Advertised {
|
||||
currentRoutes[pos].Advertised = false
|
||||
currentRoutes[pos].Enabled = false
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
err := hsdb.db.Create(&route).Error
|
||||
err := tx.Create(&route).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
|
||||
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
|
||||
// currently have a functioning host that exposes the network.
|
||||
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
|
||||
nodeRoutes, err := hsdb.getNodeRoutes(node)
|
||||
func EnsureFailoverRouteIsAvailable(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
node *types.Node,
|
||||
) (*types.StateUpdate, error) {
|
||||
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var changedNodes types.Nodes
|
||||
for _, nodeRoute := range nodeRoutes {
|
||||
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.IsPrimary {
|
||||
// if we have a primary route, and the node is connected
|
||||
// nothing needs to be done.
|
||||
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
|
||||
if isConnected[route.Node.MachineKey] {
|
||||
continue
|
||||
}
|
||||
|
||||
// if not, we need to failover the route
|
||||
err := hsdb.failoverRouteWithNotify(&route)
|
||||
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
changedNodes = append(changedNodes, update.ChangeNodes...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
|
||||
routes, err := hsdb.getNodeRoutes(node)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var changedKeys []key.MachinePublic
|
||||
|
||||
for _, route := range routes {
|
||||
changed, err := hsdb.failoverRoute(&route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
changedKeys = append(changedKeys, changed...)
|
||||
}
|
||||
|
||||
changedKeys = lo.Uniq(changedKeys)
|
||||
|
||||
var nodes types.Nodes
|
||||
|
||||
for _, key := range changedKeys {
|
||||
node, err := hsdb.GetNodeByMachineKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
if nodes != nil {
|
||||
stateUpdate := types.StateUpdate{
|
||||
if len(changedNodes) != 0 {
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: nodes,
|
||||
Message: "called from db.FailoverNodeRoutesWithNotify",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
ChangeNodes: changedNodes,
|
||||
Message: "called from db.EnsureFailoverRouteIsAvailable",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
|
||||
changedKeys, err := hsdb.failoverRoute(r)
|
||||
func failoverRouteReturnUpdate(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
r *types.Route,
|
||||
) (*types.StateUpdate, error) {
|
||||
changedKeys, err := failoverRoute(tx, isConnected, r)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Interface("isConnected", isConnected).
|
||||
Interface("changedKeys", changedKeys).
|
||||
Msg("building route failover")
|
||||
|
||||
if len(changedKeys) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var nodes types.Nodes
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("loading machines with new primary routes from db")
|
||||
|
||||
for _, key := range changedKeys {
|
||||
node, err := hsdb.getNodeByMachineKey(key)
|
||||
node, err := GetNodeByMachineKey(tx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("notifying peers about primary route change")
|
||||
|
||||
if nodes != nil {
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: nodes,
|
||||
Message: "called from db.failoverRouteWithNotify",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("notified peers about primary route change")
|
||||
|
||||
return nil
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: nodes,
|
||||
Message: "called from db.failoverRouteReturnUpdate",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// failoverRoute takes a route that is no longer available,
|
||||
|
@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
|
|||
//
|
||||
// and tries to find a new route to take over its place.
|
||||
// If the given route was not primary, it returns early.
|
||||
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
|
||||
func failoverRoute(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
r *types.Route,
|
||||
) ([]key.MachinePublic, error) {
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This route is not a primary route, and it isnt
|
||||
// This route is not a primary route, and it is not
|
||||
// being served to nodes.
|
||||
if !r.IsPrimary {
|
||||
return nil, nil
|
||||
|
@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -589,14 +547,14 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
continue
|
||||
}
|
||||
|
||||
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
|
||||
if isConnected[route.Node.MachineKey] {
|
||||
newPrimary = &routes[idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If a new route was not found/available,
|
||||
// return with an error.
|
||||
// return without an error.
|
||||
// We do not want to update the database as
|
||||
// the one currently marked as primary is the
|
||||
// best we got.
|
||||
|
@ -610,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
|
||||
// Remove primary from the old route
|
||||
r.IsPrimary = false
|
||||
err = hsdb.db.Save(&r).Error
|
||||
err = tx.Save(&r).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error disabling new primary route")
|
||||
|
||||
|
@ -623,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
|
||||
// Set primary for the new primary
|
||||
newPrimary.IsPrimary = true
|
||||
err = hsdb.db.Save(&newPrimary).Error
|
||||
err = tx.Save(&newPrimary).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error enabling new primary route")
|
||||
|
||||
|
@ -638,25 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
) error {
|
||||
if len(aclPolicy.AutoApprovers.ExitNode) == 0 && len(aclPolicy.AutoApprovers.Routes) == 0 {
|
||||
// No autoapprovers configured
|
||||
return nil
|
||||
}
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||
func EnableAutoApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
) (*types.StateUpdate, error) {
|
||||
if len(node.IPAddresses) == 0 {
|
||||
// This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
return nil
|
||||
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
routes, err := hsdb.getNodeAdvertisedRoutes(node)
|
||||
routes, err := GetNodeAdvertisedRoutes(tx, node)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -664,7 +623,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Str("node", node.Hostname).
|
||||
Msg("Could not get advertised routes for node")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
|
||||
|
@ -685,7 +644,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Uint64("nodeId", node.ID).
|
||||
Msg("Failed to resolve autoApprovers for advertised route")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
|
@ -706,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Str("alias", approvedAlias).
|
||||
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
|
@ -717,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
}
|
||||
}
|
||||
|
||||
update := &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{},
|
||||
Message: "created in db.EnableAutoApprovedRoutes",
|
||||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
err := hsdb.enableRoute(uint64(approvedRoute.ID))
|
||||
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
Uint64("nodeId", node.ID).
|
||||
Msg("Failed to enable approved route")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
|
||||
}
|
||||
|
||||
return nil
|
||||
return update, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue