Return better web errors to the user (#2398)

* add dedicated http error to propagate to user

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* classify user errors in http handlers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* move validation of pre auth key out of db

This move separates the logic a bit and allow us to
write specific errors for the caller, in this case the web
layer so we can present the user with the correct error
codes without bleeding web stuff into a generic validate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* update changelog

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-01 15:25:18 +01:00 committed by GitHub
parent 1c7f3bc440
commit 45752db0f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 268 additions and 229 deletions

View file

@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler(
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
return
}
// Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state")
if err != nil {
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
httpError(writer, err)
return
}
// Set the state and nonce cookies to protect against CSRF attacks
nonce, err := setCSRFCookie(writer, req, "nonce")
if err != nil {
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -219,34 +219,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
) {
code, state, err := extractCodeAndStateParamFromRequest(req)
if err != nil {
httpError(writer, err, err.Error(), http.StatusBadRequest)
httpError(writer, err)
return
}
cookieState, err := req.Cookie("state")
if err != nil {
httpError(writer, err, "state not found", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
return
}
if state != cookieState.Value {
httpError(writer, err, "state did not match", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil))
return
}
idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil {
httpError(writer, err, err.Error(), http.StatusBadRequest)
httpError(writer, err)
return
}
nonce, err := req.Cookie("nonce")
if err != nil {
httpError(writer, err, "nonce not found", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
}
if idToken.Nonce != nonce.Value {
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
return
}
@ -254,29 +254,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
err = fmt.Errorf("decoding ID token claims: %w", err)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
return
}
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest(
state := req.URL.Query().Get("state")
if code == "" || state == "" {
return "", "", errEmptyOIDCCallbackParams
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams)
}
return code, state, nil
@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken(
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, errNoOIDCRegistrationInfo
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken(
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err)
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, errNoOIDCIDToken
return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken)
}
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify ID token: %w", err)
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
}
return idToken, nil
@ -381,7 +380,7 @@ func validateOIDCAllowedDomains(
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
return errOIDCAllowedDomains
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains)
}
}
@ -403,7 +402,7 @@ func validateOIDCAllowedGroups(
}
}
return errOIDCAllowedGroups
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
}
return nil
@ -417,7 +416,7 @@ func validateOIDCAllowedUsers(
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
return errOIDCAllowedUsers
return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers)
}
return nil
@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil
}
func (a *AuthProviderOIDC) handleRegistrationID(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
expiry time.Time,