create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -9,17 +9,18 @@ import (
|
|||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrUserExists = Error("User already exists")
|
||||
ErrUserNotFound = Error("User not found")
|
||||
ErrUserStillHasNodes = Error("User not empty: node(s) found")
|
||||
ErrInvalidUserName = Error("Invalid user name")
|
||||
var (
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrInvalidUserName = errors.New("invalid user name")
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -40,17 +41,17 @@ type User struct {
|
|||
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func (h *Headscale) CreateUser(name string) (*User, error) {
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) {
|
||||
err := CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := User{}
|
||||
if err := h.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
user.Name = name
|
||||
if err := h.db.Create(&user).Error; err != nil {
|
||||
if err := hsdb.db.Create(&user).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateUser").
|
||||
Err(err).
|
||||
|
@ -64,13 +65,13 @@ func (h *Headscale) CreateUser(name string) (*User, error) {
|
|||
|
||||
// DestroyUser destroys a User. Returns error if the User does
|
||||
// not exist or if there are machines associated with it.
|
||||
func (h *Headscale) DestroyUser(name string) error {
|
||||
user, err := h.GetUser(name)
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
user, err := hsdb.GetUser(name)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
machines, err := h.ListMachinesByUser(name)
|
||||
machines, err := hsdb.ListMachinesByUser(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -78,18 +79,18 @@ func (h *Headscale) DestroyUser(name string) error {
|
|||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := h.ListPreAuthKeys(name)
|
||||
keys, err := hsdb.ListPreAuthKeys(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = h.DestroyPreAuthKey(key)
|
||||
err = hsdb.DestroyPreAuthKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := h.db.Unscoped().Delete(&user); result.Error != nil {
|
||||
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -98,9 +99,9 @@ func (h *Headscale) DestroyUser(name string) error {
|
|||
|
||||
// RenameUser renames a User. Returns error if the User does
|
||||
// not exist or if another User exists with the new name.
|
||||
func (h *Headscale) RenameUser(oldName, newName string) error {
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
var err error
|
||||
oldUser, err := h.GetUser(oldName)
|
||||
oldUser, err := hsdb.GetUser(oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -108,7 +109,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = h.GetUser(newName)
|
||||
_, err = hsdb.GetUser(newName)
|
||||
if err == nil {
|
||||
return ErrUserExists
|
||||
}
|
||||
|
@ -118,7 +119,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
|||
|
||||
oldUser.Name = newName
|
||||
|
||||
if result := h.db.Save(&oldUser); result.Error != nil {
|
||||
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -126,9 +127,9 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
|||
}
|
||||
|
||||
// GetUser fetches a user by name.
|
||||
func (h *Headscale) GetUser(name string) (*User, error) {
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
||||
user := User{}
|
||||
if result := h.db.First(&user, "name = ?", name); errors.Is(
|
||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -139,9 +140,9 @@ func (h *Headscale) GetUser(name string) (*User, error) {
|
|||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
func (h *Headscale) ListUsers() ([]User, error) {
|
||||
func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
||||
users := []User{}
|
||||
if err := h.db.Find(&users).Error; err != nil {
|
||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -149,18 +150,18 @@ func (h *Headscale) ListUsers() ([]User, error) {
|
|||
}
|
||||
|
||||
// ListMachinesByUser gets all the nodes in a given user.
|
||||
func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
|
||||
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||
err := CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := h.GetUser(name)
|
||||
user, err := hsdb.GetUser(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -168,17 +169,17 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
|
|||
}
|
||||
|
||||
// SetMachineUser assigns a Machine to a user.
|
||||
func (h *Headscale) SetMachineUser(machine *Machine, username string) error {
|
||||
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error {
|
||||
err := CheckForFQDNRules(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := h.GetUser(username)
|
||||
user, err := hsdb.GetUser(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
machine.User = *user
|
||||
if result := h.db.Save(&machine); result.Error != nil {
|
||||
if result := hsdb.db.Save(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -211,7 +212,7 @@ func (n *User) toTailscaleLogin() *tailcfg.Login {
|
|||
return &login
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapResponseUserProfiles(
|
||||
func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
||||
machine Machine,
|
||||
peers Machines,
|
||||
) []tailcfg.UserProfile {
|
||||
|
@ -225,8 +226,8 @@ func (h *Headscale) getMapResponseUserProfiles(
|
|||
for _, user := range userMap {
|
||||
displayName := user.Name
|
||||
|
||||
if h.cfg.BaseDomain != "" {
|
||||
displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain)
|
||||
if hsdb.baseDomain != "" {
|
||||
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
|
||||
}
|
||||
|
||||
profiles = append(profiles,
|
||||
|
@ -242,7 +243,7 @@ func (h *Headscale) getMapResponseUserProfiles(
|
|||
|
||||
func (n *User) toProto() *v1.User {
|
||||
return &v1.User{
|
||||
Id: strconv.FormatUint(uint64(n.ID), Base10),
|
||||
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
||||
Name: n.Name,
|
||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue