create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -11,13 +11,10 @@ import (
"gorm.io/gorm"
)
const (
ErrRouteIsNotAvailable = Error("route is not available")
)
var (
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
ErrRouteIsNotAvailable = errors.New("route is not available")
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
)
type Route struct {
@ -51,9 +48,9 @@ func (rs Routes) toPrefixes() []netip.Prefix {
return prefixes
}
func (h *Headscale) GetRoutes() ([]Route, error) {
func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
var routes []Route
err := h.db.Preload("Machine").Find(&routes).Error
err := hsdb.db.Preload("Machine").Find(&routes).Error
if err != nil {
return nil, err
}
@ -61,9 +58,9 @@ func (h *Headscale) GetRoutes() ([]Route, error) {
return routes, nil
}
func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) {
func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
var routes []Route
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("machine_id = ?", m.ID).
Find(&routes).Error
@ -74,9 +71,9 @@ func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) {
return routes, nil
}
func (h *Headscale) GetRoute(id uint64) (*Route, error) {
func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) {
var route Route
err := h.db.Preload("Machine").First(&route, id).Error
err := hsdb.db.Preload("Machine").First(&route, id).Error
if err != nil {
return nil, err
}
@ -84,8 +81,8 @@ func (h *Headscale) GetRoute(id uint64) (*Route, error) {
return &route, nil
}
func (h *Headscale) EnableRoute(id uint64) error {
route, err := h.GetRoute(id)
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
if err != nil {
return err
}
@ -94,14 +91,14 @@ func (h *Headscale) EnableRoute(id uint64) error {
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.isExitRoute() {
return h.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
}
return h.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
}
func (h *Headscale) DisableRoute(id uint64) error {
route, err := h.GetRoute(id)
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
if err != nil {
return err
}
@ -112,15 +109,15 @@ func (h *Headscale) DisableRoute(id uint64) error {
if !route.isExitRoute() {
route.Enabled = false
route.IsPrimary = false
err = h.db.Save(route).Error
err = hsdb.db.Save(route).Error
if err != nil {
return err
}
return h.handlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
routes, err := h.GetMachineRoutes(&route.Machine)
routes, err := hsdb.GetMachineRoutes(&route.Machine)
if err != nil {
return err
}
@ -129,18 +126,18 @@ func (h *Headscale) DisableRoute(id uint64) error {
if routes[i].isExitRoute() {
routes[i].Enabled = false
routes[i].IsPrimary = false
err = h.db.Save(&routes[i]).Error
err = hsdb.db.Save(&routes[i]).Error
if err != nil {
return err
}
}
}
return h.handlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
func (h *Headscale) DeleteRoute(id uint64) error {
route, err := h.GetRoute(id)
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
route, err := hsdb.GetRoute(id)
if err != nil {
return err
}
@ -149,14 +146,14 @@ func (h *Headscale) DeleteRoute(id uint64) error {
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.isExitRoute() {
if err := h.db.Unscoped().Delete(&route).Error; err != nil {
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
}
return h.handlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
routes, err := h.GetMachineRoutes(&route.Machine)
routes, err := hsdb.GetMachineRoutes(&route.Machine)
if err != nil {
return err
}
@ -168,32 +165,32 @@ func (h *Headscale) DeleteRoute(id uint64) error {
}
}
if err := h.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err
}
return h.handlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
func (h *Headscale) DeleteMachineRoutes(m *Machine) error {
routes, err := h.GetMachineRoutes(m)
func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
routes, err := hsdb.GetMachineRoutes(m)
if err != nil {
return err
}
for i := range routes {
if err := h.db.Unscoped().Delete(&routes[i]).Error; err != nil {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
return err
}
}
return h.handlePrimarySubnetFailover()
return hsdb.handlePrimarySubnetFailover()
}
// isUniquePrefix returns if there is another machine providing the same route already.
func (h *Headscale) isUniquePrefix(route Route) bool {
func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
var count int64
h.db.
hsdb.db.
Model(&Route{}).
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
@ -203,9 +200,9 @@ func (h *Headscale) isUniquePrefix(route Route) bool {
return count == 0
}
func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
var route Route
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
First(&route).Error
@ -222,9 +219,9 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary.
func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
var routes []Route
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
Find(&routes).Error
@ -235,9 +232,9 @@ func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
return routes, nil
}
func (h *Headscale) processMachineRoutes(machine *Machine) error {
func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
currentRoutes := []Route{}
err := h.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error
if err != nil {
return err
}
@ -251,7 +248,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised {
currentRoutes[pos].Advertised = true
err := h.db.Save(&currentRoutes[pos]).Error
err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
}
@ -260,7 +257,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
} else if route.Advertised {
currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false
err := h.db.Save(&currentRoutes[pos]).Error
err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
}
@ -275,7 +272,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
Advertised: true,
Enabled: false,
}
err := h.db.Create(&route).Error
err := hsdb.db.Create(&route).Error
if err != nil {
return err
}
@ -285,10 +282,10 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
return nil
}
func (h *Headscale) handlePrimarySubnetFailover() error {
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
// first, get all the enabled routes
var routes []Route
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
@ -303,14 +300,14 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
}
if !route.IsPrimary {
_, err := h.getPrimaryRoute(netip.Prefix(route.Prefix))
if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("machine", route.Machine.GivenName).
Msg("Setting primary route")
routes[pos].IsPrimary = true
err := h.db.Save(&routes[pos]).Error
err := hsdb.db.Save(&routes[pos]).Error
if err != nil {
log.Error().Err(err).Msg("error marking route as primary")
@ -336,7 +333,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
// find a new primary route
var newPrimaryRoutes []Route
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
@ -375,7 +372,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
// disable the old primary route
routes[pos].IsPrimary = false
err = h.db.Save(&routes[pos]).Error
err = hsdb.db.Save(&routes[pos]).Error
if err != nil {
log.Error().Err(err).Msg("error disabling old primary route")
@ -384,7 +381,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
// enable the new primary route
newPrimaryRoute.IsPrimary = true
err = h.db.Save(&newPrimaryRoute).Error
err = hsdb.db.Save(&newPrimaryRoute).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
@ -396,7 +393,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
}
if routesChanged {
h.setLastStateChangeToNow()
hsdb.notifyStateChange()
}
return nil