create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"net/netip"
"sync"
"time"
"github.com/glebarez/sqlite"
@ -19,55 +20,90 @@ import (
const (
dbVersion = "1"
Postgres = "postgres"
Sqlite = "sqlite3"
)
errValueNotFound = Error("not found")
ErrCannotParsePrefix = Error("cannot parse prefix")
var (
errValueNotFound = errors.New("not found")
ErrCannotParsePrefix = errors.New("cannot parse prefix")
errDatabaseNotSupported = errors.New("database type not supported")
)
// KV is a key-value store in a psql table. For future use...
// TODO(kradalby): Is this used for anything?
type KV struct {
Key string
Value string
}
func (h *Headscale) initDB() error {
db, err := h.openDB()
type HSDatabase struct {
db *gorm.DB
notifyStateChan chan<- struct{}
ipAllocationMutex sync.Mutex
ipPrefixes []netip.Prefix
baseDomain string
stripEmailDomain bool
}
// TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments.
func NewHeadscaleDatabase(
dbType, connectionAddr string,
stripEmailDomain, debug bool,
notifyStateChan chan<- struct{},
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(dbType, connectionAddr, debug)
if err != nil {
return err
}
h.db = db
if h.dbType == Postgres {
db.Exec(`create extension if not exists "uuid-ossp";`)
return nil, err
}
_ = db.Migrator().RenameTable("namespaces", "users")
db := HSDatabase{
db: dbConn,
notifyStateChan: notifyStateChan,
err = db.AutoMigrate(&User{})
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
stripEmailDomain: stripEmailDomain,
}
log.Debug().Msgf("database %#v", dbConn)
if dbType == Postgres {
dbConn.Exec(`create extension if not exists "uuid-ossp";`)
}
_ = dbConn.Migrator().RenameTable("namespaces", "users")
err = dbConn.AutoMigrate(User{})
if err != nil {
return err
return nil, err
}
_ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
_ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
_ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
_ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname")
// GivenName is used as the primary source of DNS names, make sure
// the field is populated and normalized if it was not when the
// machine was registered.
_ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
// If the Machine table has a column for registered,
// find all occourences of "false" and drop them. Then
// remove the column.
if db.Migrator().HasColumn(&Machine{}, "registered") {
if dbConn.Migrator().HasColumn(&Machine{}, "registered") {
log.Info().
Msg(`Database has legacy "registered" column in machine, removing...`)
machines := Machines{}
if err := h.db.Not("registered").Find(&machines).Error; err != nil {
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
@ -76,7 +112,7 @@ func (h *Headscale) initDB() error {
Str("machine", machine.Hostname).
Str("machine_key", machine.MachineKey).
Msg("Deleting unregistered machine")
if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil {
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
@ -85,18 +121,18 @@ func (h *Headscale) initDB() error {
}
}
err := db.Migrator().DropColumn(&Machine{}, "registered")
err := dbConn.Migrator().DropColumn(&Machine{}, "registered")
if err != nil {
log.Error().Err(err).Msg("Error dropping registered column")
}
}
err = db.AutoMigrate(&Route{})
err = dbConn.AutoMigrate(&Route{})
if err != nil {
return err
return nil, err
}
if db.Migrator().HasColumn(&Machine{}, "enabled_routes") {
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
type MachineAux struct {
@ -105,7 +141,7 @@ func (h *Headscale) initDB() error {
}
machinesAux := []MachineAux{}
err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
if err != nil {
log.Fatal().Err(err).Msg("Error accessing db")
}
@ -120,7 +156,7 @@ func (h *Headscale) initDB() error {
continue
}
err = db.Preload("Machine").
err = dbConn.Preload("Machine").
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
First(&Route{}).
Error
@ -138,7 +174,7 @@ func (h *Headscale) initDB() error {
Enabled: true,
Prefix: IPPrefix(prefix),
}
if err := h.db.Create(&route).Error; err != nil {
if err := dbConn.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
} else {
log.Info().
@ -149,20 +185,20 @@ func (h *Headscale) initDB() error {
}
}
err = db.Migrator().DropColumn(&Machine{}, "enabled_routes")
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes")
if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column")
}
}
err = db.AutoMigrate(&Machine{})
err = dbConn.AutoMigrate(&Machine{})
if err != nil {
return err
return nil, err
}
if db.Migrator().HasColumn(&Machine{}, "given_name") {
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") {
machines := Machines{}
if err := h.db.Find(&machines).Error; err != nil {
if err := dbConn.Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
@ -170,7 +206,7 @@ func (h *Headscale) initDB() error {
if machine.GivenName == "" {
normalizedHostname, err := NormalizeToFQDNRules(
machine.Hostname,
h.cfg.OIDC.StripEmaildomain,
stripEmailDomain,
)
if err != nil {
log.Error().
@ -180,7 +216,7 @@ func (h *Headscale) initDB() error {
Msg("Failed to normalize machine hostname in DB migration")
}
err = h.RenameMachine(&machines[item], normalizedHostname)
err = db.RenameMachine(&machines[item], normalizedHostname)
if err != nil {
log.Error().
Caller().
@ -192,51 +228,51 @@ func (h *Headscale) initDB() error {
}
}
err = db.AutoMigrate(&KV{})
err = dbConn.AutoMigrate(&KV{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&PreAuthKey{})
err = dbConn.AutoMigrate(&PreAuthKey{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&PreAuthKeyACLTag{})
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{})
if err != nil {
return err
return nil, err
}
_ = db.Migrator().DropTable("shared_machines")
_ = dbConn.Migrator().DropTable("shared_machines")
err = db.AutoMigrate(&APIKey{})
err = dbConn.AutoMigrate(&APIKey{})
if err != nil {
return err
return nil, err
}
err = h.setValue("db_version", dbVersion)
// TODO(kradalby): is this needed?
err = db.setValue("db_version", dbVersion)
return err
return &db, err
}
func (h *Headscale) openDB() (*gorm.DB, error) {
var db *gorm.DB
var err error
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
var log logger.Interface
if h.dbDebug {
log = logger.Default
var dbLogger logger.Interface
if debug {
dbLogger = logger.Default
} else {
log = logger.Default.LogMode(logger.Silent)
dbLogger = logger.Default.LogMode(logger.Silent)
}
switch h.dbType {
switch dbType {
case Sqlite:
db, err = gorm.Open(
sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"),
db, err := gorm.Open(
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
Logger: dbLogger,
},
)
@ -250,24 +286,30 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
sqlDB.SetMaxOpenConns(1)
sqlDB.SetConnMaxIdleTime(time.Hour)
return db, err
case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
Logger: dbLogger,
})
}
if err != nil {
return nil, err
}
return nil, fmt.Errorf(
"database of type %s is not supported: %w",
dbType,
errDatabaseNotSupported,
)
}
return db, nil
func (hsdb *HSDatabase) notifyStateChange() {
hsdb.notifyStateChan <- struct{}{}
}
// getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) {
func (hsdb *HSDatabase) getValue(key string) (string, error) {
var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(
if result := hsdb.db.First(&row, "key = ?", key); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -278,34 +320,34 @@ func (h *Headscale) getValue(key string) (string, error) {
}
// setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error {
func (hsdb *HSDatabase) setValue(key string, value string) error {
keyValue := KV{
Key: key,
Value: value,
}
if _, err := h.getValue(key); err == nil {
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
if _, err := hsdb.getValue(key); err == nil {
hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
return nil
}
if err := h.db.Create(keyValue).Error; err != nil {
if err := hsdb.db.Create(keyValue).Error; err != nil {
return fmt.Errorf("failed to create key value pair in the database: %w", err)
}
return nil
}
func (h *Headscale) pingDB(ctx context.Context) error {
func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
db, err := h.db.DB()
sqlDB, err := hsdb.db.DB()
if err != nil {
return err
}
return db.PingContext(ctx)
return sqlDB.PingContext(ctx)
}
// This is a "wrapper" type around tailscales