create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -9,17 +9,18 @@ import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
const (
ErrUserExists = Error("User already exists")
ErrUserNotFound = Error("User not found")
ErrUserStillHasNodes = Error("User not empty: node(s) found")
ErrInvalidUserName = Error("Invalid user name")
var (
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrInvalidUserName = errors.New("invalid user name")
)
const (
@ -40,17 +41,17 @@ type User struct {
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func (h *Headscale) CreateUser(name string) (*User, error) {
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) {
err := CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user := User{}
if err := h.db.Where("name = ?", name).First(&user).Error; err == nil {
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists
}
user.Name = name
if err := h.db.Create(&user).Error; err != nil {
if err := hsdb.db.Create(&user).Error; err != nil {
log.Error().
Str("func", "CreateUser").
Err(err).
@ -64,13 +65,13 @@ func (h *Headscale) CreateUser(name string) (*User, error) {
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyUser(name string) error {
user, err := h.GetUser(name)
func (hsdb *HSDatabase) DestroyUser(name string) error {
user, err := hsdb.GetUser(name)
if err != nil {
return ErrUserNotFound
}
machines, err := h.ListMachinesByUser(name)
machines, err := hsdb.ListMachinesByUser(name)
if err != nil {
return err
}
@ -78,18 +79,18 @@ func (h *Headscale) DestroyUser(name string) error {
return ErrUserStillHasNodes
}
keys, err := h.ListPreAuthKeys(name)
keys, err := hsdb.ListPreAuthKeys(name)
if err != nil {
return err
}
for _, key := range keys {
err = h.DestroyPreAuthKey(key)
err = hsdb.DestroyPreAuthKey(key)
if err != nil {
return err
}
}
if result := h.db.Unscoped().Delete(&user); result.Error != nil {
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
return result.Error
}
@ -98,9 +99,9 @@ func (h *Headscale) DestroyUser(name string) error {
// RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name.
func (h *Headscale) RenameUser(oldName, newName string) error {
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
var err error
oldUser, err := h.GetUser(oldName)
oldUser, err := hsdb.GetUser(oldName)
if err != nil {
return err
}
@ -108,7 +109,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
if err != nil {
return err
}
_, err = h.GetUser(newName)
_, err = hsdb.GetUser(newName)
if err == nil {
return ErrUserExists
}
@ -118,7 +119,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
oldUser.Name = newName
if result := h.db.Save(&oldUser); result.Error != nil {
if result := hsdb.db.Save(&oldUser); result.Error != nil {
return result.Error
}
@ -126,9 +127,9 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
}
// GetUser fetches a user by name.
func (h *Headscale) GetUser(name string) (*User, error) {
func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
user := User{}
if result := h.db.First(&user, "name = ?", name); errors.Is(
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -139,9 +140,9 @@ func (h *Headscale) GetUser(name string) (*User, error) {
}
// ListUsers gets all the existing users.
func (h *Headscale) ListUsers() ([]User, error) {
func (hsdb *HSDatabase) ListUsers() ([]User, error) {
users := []User{}
if err := h.db.Find(&users).Error; err != nil {
if err := hsdb.db.Find(&users).Error; err != nil {
return nil, err
}
@ -149,18 +150,18 @@ func (h *Headscale) ListUsers() ([]User, error) {
}
// ListMachinesByUser gets all the nodes in a given user.
func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
err := CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user, err := h.GetUser(name)
user, err := hsdb.GetUser(name)
if err != nil {
return nil, err
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
return nil, err
}
@ -168,17 +169,17 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
}
// SetMachineUser assigns a Machine to a user.
func (h *Headscale) SetMachineUser(machine *Machine, username string) error {
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error {
err := CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := h.GetUser(username)
user, err := hsdb.GetUser(username)
if err != nil {
return err
}
machine.User = *user
if result := h.db.Save(&machine); result.Error != nil {
if result := hsdb.db.Save(&machine); result.Error != nil {
return result.Error
}
@ -211,7 +212,7 @@ func (n *User) toTailscaleLogin() *tailcfg.Login {
return &login
}
func (h *Headscale) getMapResponseUserProfiles(
func (hsdb *HSDatabase) getMapResponseUserProfiles(
machine Machine,
peers Machines,
) []tailcfg.UserProfile {
@ -225,8 +226,8 @@ func (h *Headscale) getMapResponseUserProfiles(
for _, user := range userMap {
displayName := user.Name
if h.cfg.BaseDomain != "" {
displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain)
if hsdb.baseDomain != "" {
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
}
profiles = append(profiles,
@ -242,7 +243,7 @@ func (h *Headscale) getMapResponseUserProfiles(
func (n *User) toProto() *v1.User {
return &v1.User{
Id: strconv.FormatUint(uint64(n.ID), Base10),
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}