introduce rw lock for db, ish...
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
a1a3ff4ba8
commit
eff529f2c5
12 changed files with 369 additions and 156 deletions
|
@ -13,6 +13,13 @@ import (
|
|||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoutes()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
||||
if err != nil {
|
||||
|
@ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getMachineAdvertisedRoutes(machine)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
|
@ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getMachineRoutes(m)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
|
@ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error)
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
||||
if err != nil {
|
||||
|
@ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.enableRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) enableRoute(id uint64) error {
|
||||
route, err := hsdb.getRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
routes, err := hsdb.getMachineRoutes(&route.Machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
|||
}
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
routes, err := hsdb.getMachineRoutes(&route.Machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
|
||||
routes, err := hsdb.GetMachineRoutes(m)
|
||||
func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error {
|
||||
routes, err := hsdb.getMachineRoutes(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
|
|||
}
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another machine providing the same route already.
|
||||
|
@ -201,6 +242,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
|
|||
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
|
@ -214,6 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes,
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.processMachineRoutes(machine)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error {
|
||||
currentRoutes := types.Routes{}
|
||||
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
|
@ -264,6 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
// first, get all the enabled routes
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
|
@ -388,11 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
aclPolicy *policy.ACLPolicy,
|
||||
machine *types.Machine,
|
||||
) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if len(machine.IPAddresses) == 0 {
|
||||
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
|
||||
routes, err := hsdb.getMachineAdvertisedRoutes(machine)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -445,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
|
||||
err := hsdb.enableRoute(uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue