create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
186
hscontrol/db.go
186
hscontrol/db.go
|
@ -7,6 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
|
@ -19,55 +20,90 @@ import (
|
|||
|
||||
const (
|
||||
dbVersion = "1"
|
||||
Postgres = "postgres"
|
||||
Sqlite = "sqlite3"
|
||||
)
|
||||
|
||||
errValueNotFound = Error("not found")
|
||||
ErrCannotParsePrefix = Error("cannot parse prefix")
|
||||
var (
|
||||
errValueNotFound = errors.New("not found")
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
errDatabaseNotSupported = errors.New("database type not supported")
|
||||
)
|
||||
|
||||
// KV is a key-value store in a psql table. For future use...
|
||||
// TODO(kradalby): Is this used for anything?
|
||||
type KV struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (h *Headscale) initDB() error {
|
||||
db, err := h.openDB()
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
ipPrefixes []netip.Prefix
|
||||
baseDomain string
|
||||
stripEmailDomain bool
|
||||
}
|
||||
|
||||
// TODO(kradalby): assemble this struct from toptions or something typed
|
||||
// rather than arguments.
|
||||
func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
stripEmailDomain, debug bool,
|
||||
notifyStateChan chan<- struct{},
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(dbType, connectionAddr, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.db = db
|
||||
|
||||
if h.dbType == Postgres {
|
||||
db.Exec(`create extension if not exists "uuid-ossp";`)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = db.Migrator().RenameTable("namespaces", "users")
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
|
||||
err = db.AutoMigrate(&User{})
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
stripEmailDomain: stripEmailDomain,
|
||||
}
|
||||
|
||||
log.Debug().Msgf("database %#v", dbConn)
|
||||
|
||||
if dbType == Postgres {
|
||||
dbConn.Exec(`create extension if not exists "uuid-ossp";`)
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
||||
|
||||
err = dbConn.AutoMigrate(User{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
||||
_ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
||||
|
||||
_ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
||||
_ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
||||
|
||||
// GivenName is used as the primary source of DNS names, make sure
|
||||
// the field is populated and normalized if it was not when the
|
||||
// machine was registered.
|
||||
_ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
||||
|
||||
// If the Machine table has a column for registered,
|
||||
// find all occourences of "false" and drop them. Then
|
||||
// remove the column.
|
||||
if db.Migrator().HasColumn(&Machine{}, "registered") {
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "registered") {
|
||||
log.Info().
|
||||
Msg(`Database has legacy "registered" column in machine, removing...`)
|
||||
|
||||
machines := Machines{}
|
||||
if err := h.db.Not("registered").Find(&machines).Error; err != nil {
|
||||
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
|
@ -76,7 +112,7 @@ func (h *Headscale) initDB() error {
|
|||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Msg("Deleting unregistered machine")
|
||||
if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil {
|
||||
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
|
@ -85,18 +121,18 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
err := db.Migrator().DropColumn(&Machine{}, "registered")
|
||||
err := dbConn.Migrator().DropColumn(&Machine{}, "registered")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping registered column")
|
||||
}
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&Route{})
|
||||
err = dbConn.AutoMigrate(&Route{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
||||
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
||||
|
||||
type MachineAux struct {
|
||||
|
@ -105,7 +141,7 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
|
||||
machinesAux := []MachineAux{}
|
||||
err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
|
||||
err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
@ -120,7 +156,7 @@ func (h *Headscale) initDB() error {
|
|||
continue
|
||||
}
|
||||
|
||||
err = db.Preload("Machine").
|
||||
err = dbConn.Preload("Machine").
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
||||
First(&Route{}).
|
||||
Error
|
||||
|
@ -138,7 +174,7 @@ func (h *Headscale) initDB() error {
|
|||
Enabled: true,
|
||||
Prefix: IPPrefix(prefix),
|
||||
}
|
||||
if err := h.db.Create(&route).Error; err != nil {
|
||||
if err := dbConn.Create(&route).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error creating route")
|
||||
} else {
|
||||
log.Info().
|
||||
|
@ -149,20 +185,20 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
err = db.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
||||
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||
}
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&Machine{})
|
||||
err = dbConn.AutoMigrate(&Machine{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db.Migrator().HasColumn(&Machine{}, "given_name") {
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") {
|
||||
machines := Machines{}
|
||||
if err := h.db.Find(&machines).Error; err != nil {
|
||||
if err := dbConn.Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
|
@ -170,7 +206,7 @@ func (h *Headscale) initDB() error {
|
|||
if machine.GivenName == "" {
|
||||
normalizedHostname, err := NormalizeToFQDNRules(
|
||||
machine.Hostname,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -180,7 +216,7 @@ func (h *Headscale) initDB() error {
|
|||
Msg("Failed to normalize machine hostname in DB migration")
|
||||
}
|
||||
|
||||
err = h.RenameMachine(&machines[item], normalizedHostname)
|
||||
err = db.RenameMachine(&machines[item], normalizedHostname)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -192,51 +228,51 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&KV{})
|
||||
err = dbConn.AutoMigrate(&KV{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&PreAuthKey{})
|
||||
err = dbConn.AutoMigrate(&PreAuthKey{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&PreAuthKeyACLTag{})
|
||||
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = db.Migrator().DropTable("shared_machines")
|
||||
_ = dbConn.Migrator().DropTable("shared_machines")
|
||||
|
||||
err = db.AutoMigrate(&APIKey{})
|
||||
err = dbConn.AutoMigrate(&APIKey{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = h.setValue("db_version", dbVersion)
|
||||
// TODO(kradalby): is this needed?
|
||||
err = db.setValue("db_version", dbVersion)
|
||||
|
||||
return err
|
||||
return &db, err
|
||||
}
|
||||
|
||||
func (h *Headscale) openDB() (*gorm.DB, error) {
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
||||
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
|
||||
|
||||
var log logger.Interface
|
||||
if h.dbDebug {
|
||||
log = logger.Default
|
||||
var dbLogger logger.Interface
|
||||
if debug {
|
||||
dbLogger = logger.Default
|
||||
} else {
|
||||
log = logger.Default.LogMode(logger.Silent)
|
||||
dbLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
switch h.dbType {
|
||||
switch dbType {
|
||||
case Sqlite:
|
||||
db, err = gorm.Open(
|
||||
sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"),
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
|
||||
&gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: log,
|
||||
Logger: dbLogger,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -250,24 +286,30 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
|||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetConnMaxIdleTime(time.Hour)
|
||||
|
||||
return db, err
|
||||
|
||||
case Postgres:
|
||||
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
|
||||
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: log,
|
||||
Logger: dbLogger,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"database of type %s is not supported: %w",
|
||||
dbType,
|
||||
errDatabaseNotSupported,
|
||||
)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
func (hsdb *HSDatabase) notifyStateChange() {
|
||||
hsdb.notifyStateChan <- struct{}{}
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV.
|
||||
func (h *Headscale) getValue(key string) (string, error) {
|
||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
if result := h.db.First(&row, "key = ?", key); errors.Is(
|
||||
if result := hsdb.db.First(&row, "key = ?", key); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -278,34 +320,34 @@ func (h *Headscale) getValue(key string) (string, error) {
|
|||
}
|
||||
|
||||
// setValue sets value for the given key in KV.
|
||||
func (h *Headscale) setValue(key string, value string) error {
|
||||
func (hsdb *HSDatabase) setValue(key string, value string) error {
|
||||
keyValue := KV{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
|
||||
if _, err := h.getValue(key); err == nil {
|
||||
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||
if _, err := hsdb.getValue(key); err == nil {
|
||||
hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.db.Create(keyValue).Error; err != nil {
|
||||
if err := hsdb.db.Create(keyValue).Error; err != nil {
|
||||
return fmt.Errorf("failed to create key value pair in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) pingDB(ctx context.Context) error {
|
||||
func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
db, err := h.db.DB()
|
||||
sqlDB, err := hsdb.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.PingContext(ctx)
|
||||
return sqlDB.PingContext(ctx)
|
||||
}
|
||||
|
||||
// This is a "wrapper" type around tailscales
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue