introduce rw lock for db, ish...

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-07-17 13:35:05 +02:00 committed by Kristoffer Dalby
parent a1a3ff4ba8
commit eff529f2c5
12 changed files with 369 additions and 156 deletions

View file

@ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
// TODO(kradalby): figure out this lock
// hsdb.mu.Lock()
// defer hsdb.mu.Unlock()
user, err := hsdb.GetUser(userName)
if err != nil {
return nil, err
@ -92,7 +96,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
user, err := hsdb.GetUser(userName)
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listPreAuthKeys(userName)
}
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
user, err := hsdb.getUser(userName)
if err != nil {
return nil, err
}
@ -107,6 +118,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er
// GetPreAuthKey returns a PreAuthKey for a given key.
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
pak, err := hsdb.ValidatePreAuthKey(key)
if err != nil {
return nil, err
@ -122,6 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist.
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.destroyPreAuthKey(pak)
}
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
return hsdb.db.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
return result.Error
@ -137,6 +158,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
@ -146,6 +170,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
// UsePreAuthKey marks a PreAuthKey as used.
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
k.Used = true
if err := hsdb.db.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err)
@ -157,6 +184,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
pak := types.PreAuthKey{}
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error,
@ -174,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
}
machines := types.Machines{}
if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
if err := hsdb.db.
Preload("AuthKey").
Where(&types.Machine{AuthKeyID: uint(pak.ID)}).
Find(&machines).Error; err != nil {
return nil, err
}