introduce rw lock for db, ish...

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-07-17 13:35:05 +02:00 committed by Kristoffer Dalby
parent a1a3ff4ba8
commit eff529f2c5
12 changed files with 369 additions and 156 deletions

View file

@ -13,6 +13,13 @@ import (
var ErrRouteIsNotAvailable = errors.New("route is not available")
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoutes()
}
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
var routes types.Routes
err := hsdb.db.Preload("Machine").Find(&routes).Error
if err != nil {
@ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
}
func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getMachineAdvertisedRoutes(machine)
}
func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Machine").
@ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type
}
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getMachineRoutes(m)
}
func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Machine").
@ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error)
}
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoute(id)
}
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
var route types.Route
err := hsdb.db.Preload("Machine").First(&route, id).Error
if err != nil {
@ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
}
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.enableRoute(id)
}
func (hsdb *HSDatabase) enableRoute(id uint64) error {
route, err := hsdb.getRoute(id)
if err != nil {
return err
}
@ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error {
}
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
route, err := hsdb.getRoute(id)
if err != nil {
return err
}
@ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
return err
}
return hsdb.HandlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
routes, err := hsdb.GetMachineRoutes(&route.Machine)
routes, err := hsdb.getMachineRoutes(&route.Machine)
if err != nil {
return err
}
@ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
}
}
return hsdb.HandlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
route, err := hsdb.getRoute(id)
if err != nil {
return err
}
@ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err
}
return hsdb.HandlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
routes, err := hsdb.GetMachineRoutes(&route.Machine)
routes, err := hsdb.getMachineRoutes(&route.Machine)
if err != nil {
return err
}
@ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err
}
return hsdb.HandlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
routes, err := hsdb.GetMachineRoutes(m)
func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error {
routes, err := hsdb.getMachineRoutes(m)
if err != nil {
return err
}
@ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
}
}
return hsdb.HandlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
// isUniquePrefix returns if there is another machine providing the same route already.
@ -201,6 +242,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary.
func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
var routes types.Routes
err := hsdb.db.
Preload("Machine").
@ -214,6 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes,
}
func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.processMachineRoutes(machine)
}
func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error {
currentRoutes := types.Routes{}
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error
if err != nil {
@ -264,6 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
}
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.handlePrimarySubnetFailover()
}
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
// first, get all the enabled routes
var routes types.Routes
err := hsdb.db.
@ -388,11 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
machine *types.Machine,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if len(machine.IPAddresses) == 0 {
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
routes, err := hsdb.getMachineAdvertisedRoutes(machine)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().
Caller().
@ -445,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
}
for _, approvedRoute := range approvedRoutes {
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
err := hsdb.enableRoute(uint64(approvedRoute.ID))
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).