Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
cbf57e27a7
commit
83769ba715
32 changed files with 1496 additions and 1128 deletions
|
@ -20,7 +20,6 @@ var (
|
|||
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||
)
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
userName string,
|
||||
reusable bool,
|
||||
|
@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
// TODO(kradalby): figure out this lock
|
||||
// hsdb.mu.Lock()
|
||||
// defer hsdb.mu.Unlock()
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
|
||||
})
|
||||
}
|
||||
|
||||
user, err := hsdb.GetUser(userName)
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func CreatePreAuthKey(
|
||||
tx *gorm.DB,
|
||||
userName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
kstr, err := hsdb.generateKey()
|
||||
kstr, err := generateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -63,29 +72,25 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err = hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
if err := db.Save(&key).Error; err != nil {
|
||||
return fmt.Errorf("failed to create key in the database: %w", err)
|
||||
}
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create key in the database: %w", err)
|
||||
}
|
||||
|
||||
if len(aclTags) > 0 {
|
||||
seenTags := map[string]bool{}
|
||||
if len(aclTags) > 0 {
|
||||
seenTags := map[string]bool{}
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if !seenTags[tag] {
|
||||
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
seenTags[tag] = true
|
||||
for _, tag := range aclTags {
|
||||
if !seenTags[tag] {
|
||||
if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
seenTags[tag] = true
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
return &key, nil
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listPreAuthKeys(userName)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeys(rx, userName)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := hsdb.getUser(userName)
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []types.PreAuthKey{}
|
||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
|
|||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
pak, err := hsdb.ValidatePreAuthKey(key)
|
||||
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
|
||||
pak, err := ValidatePreAuthKey(tx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
|
|||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.destroyPreAuthKey(pak)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
|
||||
return tx.Transaction(func(db *gorm.DB) error {
|
||||
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
|
|||
})
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return ExpirePreAuthKey(tx, k)
|
||||
})
|
||||
}
|
||||
|
||||
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
|||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
if err := hsdb.db.Save(k).Error; err != nil {
|
||||
if err := tx.Save(k).Error; err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return ValidatePreAuthKey(rx, k)
|
||||
})
|
||||
}
|
||||
|
||||
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
|
|||
}
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
|
||||
Find(&nodes).Error; err != nil {
|
||||
|
@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
|
|||
return &pak, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||
func generateKey() (string, error) {
|
||||
size := 24
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue