Redo OIDC configuration (#2020)

expand user, add claims to user

This commit expands the user table with additional fields that
can be retrieved from OIDC providers (and other places) and
uses this data in various tailscale response objects if it is
available.

This is the beginning of implementing
https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit
trying to make OIDC more coherant and maintainable in addition
to giving the user a better experience and integration with a
provider.

remove usernames in magic dns, normalisation of emails

this commit removes the option to have usernames as part of MagicDNS
domains and headscale will now align with Tailscale, where there is a
root domain, and the machine name.

In addition, the various normalisation functions for dns names has been
made lighter not caring about username and special character that wont
occur.

Email are no longer normalised as part of the policy processing.

untagle oidc and regcache, use typed cache

This commits stops reusing the registration cache for oidc
purposes and switches the cache to be types and not use any
allowing the removal of a bunch of casting.

try to make reauth/register branches clearer in oidc

Currently there was a function that did a bunch of stuff,
finding the machine key, trying to find the node, reauthing
the node, returning some status, and it was called validate
which was very confusing.

This commit tries to split this into what to do if the node
exists, if it needs to register etc.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-02 14:50:17 +02:00 committed by GitHub
parent bc9e83b52e
commit 218138afee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 628 additions and 917 deletions

View file

@ -17,12 +17,13 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key"
"zgo.at/zcache/v2"
)
const (
@ -45,49 +46,81 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
registrationCache *zcache.Cache[string, key.MachinePublic]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
}
func (h *Headscale) initOIDC() error {
func NewAuthProviderOIDC(
ctx context.Context,
serverURL string,
cfg *types.OIDCConfig,
db *db.HSDatabase,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
return fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: h.cfg.OIDC.Scope,
}
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
return nil
oauth2Config := &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, key.MachinePublic](
registerCacheExpiration,
registerCacheCleanup,
)
return &AuthProviderOIDC{
serverURL: serverURL,
cfg: cfg,
db: db,
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if h.cfg.OIDC.UseExpiryFromToken {
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
if a.cfg.UseExpiryFromToken {
return idTokenExpiration
}
return time.Now().Add(h.cfg.OIDC.Expiry)
return time.Now().Add(a.cfg.Expiry)
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
// Listens in /register/:mKey.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -108,46 +141,32 @@ func (h *Headscale) RegisterOIDC(
[]byte(machineKeyStr),
)
if err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
util.LogErr(err, "could not read 16 bytes from rand")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(
a.registrationCache.Set(
stateStr,
machineKey,
registerCacheExpiration,
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
for k, v := range h.cfg.OIDC.ExtraParams {
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound)
@ -165,216 +184,165 @@ var oidcCallbackTemplate = template.Must(
template.New("oidccallback").Parse(oidcCallbackTemplateContent),
)
// OIDCCallback handles the callback from the OIDC endpoint
// OIDCCallbackHandler handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the node to the users email user
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into node HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(
func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer http.ResponseWriter,
req *http.Request,
) {
code, state, err := validateOIDCCallbackParams(writer, req)
code, state, err := extractCodeAndStateParamFromRequest(req)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
idToken, err := a.extractIDToken(req.Context(), code)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
if err != nil {
return
}
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
claims, err := extractIDTokenClaims(writer, idToken)
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
// Retrieve the node and the machine key from the state cache and
// database.
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
node, mKey := a.getMachineKeyFromState(state)
// Reauthenticate the node if it does exists.
if node != nil {
err := a.reauthenticateNode(node, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
// TODO(kradalby): replace with go-elem
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Reauthenticated",
}); err != nil {
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
// Register the node if it does not exist.
if mKey != nil {
if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
content, err := renderOIDCCallbackTemplate(user)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
return
}
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
writer,
state,
claims,
idTokenExpiry,
)
if err != nil || nodeExists {
return
}
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
if err != nil {
return
}
// register the node if it's new
log.Debug().Msg("Registering new node after successful callback")
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
if err != nil {
return
}
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
return
}
content, err := renderOIDCCallbackTemplate(writer, claims)
if err != nil {
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write response")
}
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
func validateOIDCCallbackParams(
writer http.ResponseWriter,
func extractCodeAndStateParamFromRequest(
req *http.Request,
) (string, string, error) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
if code == "" || state == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", "", errEmptyOIDCCallbackParams
}
return code, state, nil
}
func (h *Headscale) getIDTokenForOIDCCallback(
// extractIDToken takes the code parameter from the callback
// and extracts the ID token from the oauth2 token.
func (a *AuthProviderOIDC) extractIDToken(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
if err != nil {
util.LogErr(err, "Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
}
log.Trace().
Caller().
Str("code", code).
Str("state", state).
Msg("Got oidc callback")
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", errNoOIDCIDToken
}
return rawIDToken, nil
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
code string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err)
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, errNoOIDCIDToken
}
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
util.LogErr(err, "failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
return idToken, nil
}
func extractIDTokenClaims(
writer http.ResponseWriter,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
util.LogErr(err, "Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
}
return &claims, nil
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
writer http.ResponseWriter,
allowedDomains []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
log.Trace().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedDomains
}
}
@ -387,9 +355,8 @@ func validateOIDCAllowedDomains(
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
writer http.ResponseWriter,
allowedGroups []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
@ -398,14 +365,6 @@ func validateOIDCAllowedGroups(
}
}
log.Trace().Msg("authenticated principal not in any allowed groups")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedGroups
}
@ -415,249 +374,129 @@ func validateOIDCAllowedGroups(
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
// that the authenticated principal is part of that list.
func validateOIDCAllowedUsers(
writer http.ResponseWriter,
allowedUsers []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
log.Trace().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedUsers
}
return nil
}
// validateNode retrieves node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
func (h *Headscale) validateNodeForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
expiry time.Time,
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCInvalidNodeState
// getMachineKeyFromState retrieves the machine key from the state
// cache. If the machine key is found, it will try retrieve the
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
machineKey, ok := a.registrationCache.Get(state)
if !ok {
return nil, nil
}
// retrieve node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByMachineKey(machineKey)
node, _ := a.db.GetNodeByMachineKey(machineKey)
if node != nil {
log.Trace().
Caller().
Str("node", node.Hostname).
Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "Failed to refresh node")
http.Error(
writer,
"Failed to refresh node",
http.StatusInternalServerError,
)
return nil, true, err
}
log.Debug().
Str("node", node.Hostname).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("successfully refreshed node")
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
Verb: "Reauthenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, true, fmt.Errorf("rendering OIDC callback template: %w", err)
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
h.nodeNotifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil, true, nil
}
return &machineKey, false, nil
return node, &machineKey
}
func getUserName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
userName, err := util.NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
)
// reauthenticateNode updates the node expiry in the database
// and notifies the node and its peers about the change.
func (a *AuthProviderOIDC) reauthenticateNode(
node *types.Node,
expiry time.Time,
) error {
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
return err
}
return userName, nil
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil
}
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter,
userName string,
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
user, err := h.db.GetUser(userName)
if errors.Is(err, db.ErrUserNotFound) {
user, err = h.db.CreateUser(userName)
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return nil, fmt.Errorf("creating new user: %w", err)
}
} else if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
// This check is for legacy, if the user cannot be found by the OIDC identifier
// look it up by username. This should only be needed once.
if user == nil {
user, err = a.db.GetUserByName(claims.Username)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return nil, fmt.Errorf("find or create user: %w", err)
// if the user is still not found, create a new empty user.
if user == nil {
user = &types.User{}
}
}
user.FromClaim(claims)
err = a.db.DB.Save(user).Error
if err != nil {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return user, nil
}
func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter,
func (a *AuthProviderOIDC) registerNode(
user *types.User,
machineKey *key.MachinePublic,
expiry time.Time,
) error {
ipv4, ipv6, err := h.ipAlloc.Next()
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
}
if err := h.db.Write(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache,
*machineKey,
user.Name,
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
); err != nil {
return err
}
return nil
}); err != nil {
util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register node"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return err
if _, err := a.db.RegisterNodeFromAuthCallback(
*machineKey,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
); err != nil {
return fmt.Errorf("could not register node: %w", err)
}
return nil
}
// TODO(kradalby):
// Rewrite in elem-go
func renderOIDCCallbackTemplate(
writer http.ResponseWriter,
claims *IDTokenClaims,
user *types.User,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
User: user.DisplayNameOrUsername(),
Verb: "Authenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}