fix oidc test, add tests for migration

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-18 06:59:27 -06:00 committed by Juan Font
parent 2fe65624c0
commit 4dd12a2f97
7 changed files with 475 additions and 49 deletions

View file

@ -436,7 +436,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
) (*types.User, error) {
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
@ -448,10 +448,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
// TODO(kradalby): Remove when strip_email_domain and migration is removed
// after #2170 is cleaned up.
if a.cfg.MapLegacyUsers && user == nil {
log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username")
if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil {
log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username")
user, err = a.db.GetUserByName(oldUsername)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
return nil, fmt.Errorf("getting user: %w", err)
}
// If the user exists, but it already has a provider identifier (OIDC sub), create a new user.
@ -525,6 +527,9 @@ func getUserName(
claims *types.OIDCClaims,
stripEmaildomain bool,
) (string, error) {
if !claims.EmailVerified {
return "", fmt.Errorf("email not verified")
}
userName, err := util.NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,

View file

@ -908,7 +908,10 @@ func LoadServerConfig() (*Config, error) {
}
}(),
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
// TODO(kradalby): Remove when strip_email_domain is removed
// after #2170 is cleaned up
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
},
LogTail: logTailConfig,

View file

@ -3,7 +3,6 @@ package types
import (
"cmp"
"strconv"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
@ -39,7 +38,7 @@ type User struct {
// Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token
// and is used to lookup the user.
ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"`
ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"`
// Provider is the origin of the user account,
// same as RegistrationMethod, without authkey.
@ -58,9 +57,10 @@ type User struct {
// If the username does not contain an '@' it will be added to the end.
func (u *User) Username() string {
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
if !strings.Contains(username, "@") {
username = username + "@"
}
// TODO(kradalby): Wire up all of this for the future
// if !strings.Contains(username, "@") {
// username = username + "@"
// }
return username
}
@ -138,10 +138,14 @@ type OIDCClaims struct {
Username string `json:"preferred_username,omitempty"`
}
func (c *OIDCClaims) Identifier() string {
return c.Iss + "/" + c.Sub
}
// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Iss + "/" + claims.Sub
u.ProviderIdentifier = claims.Identifier()
u.DisplayName = claims.Name
if claims.EmailVerified {
u.Email = claims.Email