introduce rw lock for db, ish...

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-07-17 13:35:05 +02:00 committed by Kristoffer Dalby
parent a1a3ff4ba8
commit eff529f2c5
12 changed files with 369 additions and 156 deletions

View file

@ -18,6 +18,9 @@ var (
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
@ -42,12 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are machines associated with it.
func (hsdb *HSDatabase) DestroyUser(name string) error {
user, err := hsdb.GetUser(name)
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
user, err := hsdb.getUser(name)
if err != nil {
return ErrUserNotFound
}
machines, err := hsdb.ListMachinesByUser(name)
machines, err := hsdb.listMachinesByUser(name)
if err != nil {
return err
}
@ -55,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
return ErrUserStillHasNodes
}
keys, err := hsdb.ListPreAuthKeys(name)
keys, err := hsdb.listPreAuthKeys(name)
if err != nil {
return err
}
for _, key := range keys {
err = hsdb.DestroyPreAuthKey(key)
err = hsdb.destroyPreAuthKey(key)
if err != nil {
return err
}
@ -76,8 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
// RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name.
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
var err error
oldUser, err := hsdb.GetUser(oldName)
oldUser, err := hsdb.getUser(oldName)
if err != nil {
return err
}
@ -85,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
if err != nil {
return err
}
_, err = hsdb.GetUser(newName)
_, err = hsdb.getUser(newName)
if err == nil {
return ErrUserExists
}
@ -104,6 +113,13 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
// GetUser fetches a user by name.
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getUser(name)
}
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
user := types.User{}
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
result.Error,
@ -117,6 +133,13 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
// ListUsers gets all the existing users.
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listUsers()
}
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
users := []types.User{}
if err := hsdb.db.Find(&users).Error; err != nil {
return nil, err
@ -127,11 +150,18 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
// ListMachinesByUser gets all the nodes in a given user.
func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listMachinesByUser(name)
}
func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) {
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user, err := hsdb.GetUser(name)
user, err := hsdb.getUser(name)
if err != nil {
return nil, err
}
@ -144,13 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error)
return machines, nil
}
// SetMachineUser assigns a Machine to a user.
func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error {
// AssignMachineToUser assigns a Machine to a user.
func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
err := util.CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := hsdb.GetUser(username)
user, err := hsdb.getUser(username)
if err != nil {
return err
}