feat: Add PKCE Verifier for OIDC (#2314)

* feat: add PKCE verifier for OIDC

* Update CHANGELOG.md
This commit is contained in:
Rorical 2024-12-23 00:46:36 +08:00 committed by GitHub
parent 9313e5b058
commit b81420bef1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 187 additions and 15 deletions

View file

@ -28,12 +28,14 @@ import (
)
const (
randomByteSize = 16
randomByteSize = 16
defaultOAuthOptionsCount = 3
)
var (
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
errOIDCAllowedDomains = errors.New(
"authenticated principal does not match any allowed domain",
)
@ -47,11 +49,17 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier *string
}
type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
registrationCache *zcache.Cache[string, key.MachinePublic]
registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
polMan policy.PolicyManager
@ -87,7 +95,7 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, key.MachinePublic](
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
)
@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler(
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
a.registrationCache.Set(
stateStr,
machineKey,
)
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
MachineKey: machineKey,
}
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
// Add PKCE verification if enabled
if a.cfg.PKCE.Enabled {
verifier := oauth2.GenerateVerifier()
registrationInfo.Verifier = &verifier
extras = append(extras, oauth2.AccessTypeOffline)
switch a.cfg.PKCE.Method {
case types.PKCEMethodS256:
extras = append(extras, oauth2.S256ChallengeOption(verifier))
case types.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
}
}
// Add any extra parameters from configuration
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
// Cache the registration info
a.registrationCache.Set(stateStr, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
idToken, err := a.extractIDToken(req.Context(), code)
idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest(
func (a *AuthProviderOIDC) extractIDToken(
ctx context.Context,
code string,
state string,
) (*oidc.IDToken, error) {
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, errNoOIDCRegistrationInfo
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
}
}
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err)
}
@ -394,7 +432,7 @@ func validateOIDCAllowedUsers(
// cache. If the machine key is found, it will try retrieve the
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
machineKey, ok := a.registrationCache.Get(state)
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, nil
}
@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := a.db.GetNodeByMachineKey(machineKey)
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
return node, &machineKey
return node, &regInfo.MachineKey
}
// reauthenticateNode updates the node expiry in the database