fix constraints
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
5e7c3153b9
commit
281025bb16
5 changed files with 122 additions and 11 deletions
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
@ -257,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) {
|
|||
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
func TestConstraints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(*testing.T, *gorm.DB)
|
||||
}{
|
||||
{
|
||||
name: "no-duplicate-username-if-no-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
_, err := CreateUser(db, "user1")
|
||||
require.NoError(t, err)
|
||||
_, err = CreateUser(db, "user1")
|
||||
require.Error(t, err)
|
||||
// assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username")
|
||||
require.Contains(t, err.Error(), "user already exists")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-username-and-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
}
|
||||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||
|
||||
err := db.Save(&user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
user = types.User{
|
||||
Model: gorm.Model{ID: 2},
|
||||
Name: "user1",
|
||||
}
|
||||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||
|
||||
err = db.Save(&user).Error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
}
|
||||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||
|
||||
err := db.Save(&user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
user = types.User{
|
||||
Model: gorm.Model{ID: 2},
|
||||
Name: "user1.1",
|
||||
}
|
||||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||
|
||||
err = db.Save(&user).Error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-cli-then-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
_, err := CreateUser(db, "user1") // Create CLI username
|
||||
require.NoError(t, err)
|
||||
|
||||
user := types.User{
|
||||
Name: "user1",
|
||||
}
|
||||
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||
|
||||
err = db.Save(&user).Error
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-oidc-then-cli",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
user := types.User{
|
||||
Name: "user1",
|
||||
}
|
||||
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||
|
||||
err := db.Save(&user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = CreateUser(db, "user1") // Create CLI username
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("creating database: %s", err)
|
||||
}
|
||||
|
||||
tt.run(t, db.DB)
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := types.User{}
|
||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
user := types.User{
|
||||
Name: name,
|
||||
}
|
||||
user.Name = name
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
return nil, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
@ -177,6 +175,10 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
if len(users) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue