feat: Add PKCE Verifier for OIDC (#2314)
* feat: add PKCE verifier for OIDC * Update CHANGELOG.md
This commit is contained in:
parent
9313e5b058
commit
b81420bef1
7 changed files with 187 additions and 15 deletions
|
@ -28,12 +28,14 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
randomByteSize = 16
|
||||
randomByteSize = 16
|
||||
defaultOAuthOptionsCount = 3
|
||||
)
|
||||
|
||||
var (
|
||||
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
||||
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
|
||||
errOIDCAllowedDomains = errors.New(
|
||||
"authenticated principal does not match any allowed domain",
|
||||
)
|
||||
|
@ -47,11 +49,17 @@ var (
|
|||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||
)
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
MachineKey key.MachinePublic
|
||||
Verifier *string
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
db *db.HSDatabase
|
||||
registrationCache *zcache.Cache[string, key.MachinePublic]
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
notifier *notifier.Notifier
|
||||
ipAlloc *db.IPAllocator
|
||||
polMan policy.PolicyManager
|
||||
|
@ -87,7 +95,7 @@ func NewAuthProviderOIDC(
|
|||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[string, key.MachinePublic](
|
||||
registrationCache := zcache.New[string, RegistrationInfo](
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
)
|
||||
|
@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
a.registrationCache.Set(
|
||||
stateStr,
|
||||
machineKey,
|
||||
)
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
MachineKey: machineKey,
|
||||
}
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||
// Add PKCE verification if enabled
|
||||
if a.cfg.PKCE.Enabled {
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
registrationInfo.Verifier = &verifier
|
||||
|
||||
extras = append(extras, oauth2.AccessTypeOffline)
|
||||
|
||||
switch a.cfg.PKCE.Method {
|
||||
case types.PKCEMethodS256:
|
||||
extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
||||
case types.PKCEMethodPlain:
|
||||
// oauth2 does not have a plain challenge option, so we add it manually
|
||||
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
|
||||
}
|
||||
}
|
||||
|
||||
// Add any extra parameters from configuration
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(stateStr, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
|
@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
return
|
||||
}
|
||||
|
||||
idToken, err := a.extractIDToken(req.Context(), code)
|
||||
idToken, err := a.extractIDToken(req.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest(
|
|||
func (a *AuthProviderOIDC) extractIDToken(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
state string,
|
||||
) (*oidc.IDToken, error) {
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
|
||||
var exchangeOpts []oauth2.AuthCodeOption
|
||||
|
||||
if a.cfg.PKCE.Enabled {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, errNoOIDCRegistrationInfo
|
||||
}
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
}
|
||||
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not exchange code for token: %w", err)
|
||||
}
|
||||
|
@ -394,7 +432,7 @@ func validateOIDCAllowedUsers(
|
|||
// cache. If the machine key is found, it will try retrieve the
|
||||
// node information from the database.
|
||||
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
|
||||
machineKey, ok := a.registrationCache.Get(state)
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := a.db.GetNodeByMachineKey(machineKey)
|
||||
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
|
||||
|
||||
return node, &machineKey
|
||||
return node, ®Info.MachineKey
|
||||
}
|
||||
|
||||
// reauthenticateNode updates the node expiry in the database
|
||||
|
|
|
@ -26,11 +26,14 @@ import (
|
|||
const (
|
||||
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
|
||||
maxDuration time.Duration = 1<<63 - 1
|
||||
PKCEMethodPlain string = "plain"
|
||||
PKCEMethodS256 string = "S256"
|
||||
)
|
||||
|
||||
var (
|
||||
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
||||
)
|
||||
|
||||
type IPAllocationStrategy string
|
||||
|
@ -162,6 +165,11 @@ type LetsEncryptConfig struct {
|
|||
ChallengeType string
|
||||
}
|
||||
|
||||
type PKCEConfig struct {
|
||||
Enabled bool
|
||||
Method string
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
OnlyStartIfOIDCIsAvailable bool
|
||||
Issuer string
|
||||
|
@ -176,6 +184,7 @@ type OIDCConfig struct {
|
|||
Expiry time.Duration
|
||||
UseExpiryFromToken bool
|
||||
MapLegacyUsers bool
|
||||
PKCE PKCEConfig
|
||||
}
|
||||
|
||||
type DERPConfig struct {
|
||||
|
@ -226,6 +235,13 @@ type Tuning struct {
|
|||
NodeMapSessionBufferedChanSize int
|
||||
}
|
||||
|
||||
func validatePKCEMethod(method string) error {
|
||||
if method != PKCEMethodPlain && method != PKCEMethodS256 {
|
||||
return errInvalidPKCEMethod
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig prepares and loads the Headscale configuration into Viper.
|
||||
// This means it sets the default values, reads the configuration file and
|
||||
// environment variables, and handles deprecated configuration options.
|
||||
|
@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("oidc.expiry", "180d")
|
||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||
viper.SetDefault("oidc.map_legacy_users", true)
|
||||
viper.SetDefault("oidc.pkce.enabled", false)
|
||||
viper.SetDefault("oidc.pkce.method", "S256")
|
||||
|
||||
viper.SetDefault("logtail.enabled", false)
|
||||
viper.SetDefault("randomize_client_port", false)
|
||||
|
@ -340,6 +358,12 @@ func validateServerConfig() error {
|
|||
// after #2170 is cleaned up
|
||||
// depr.fatal("oidc.strip_email_domain")
|
||||
|
||||
if viper.GetBool("oidc.enabled") {
|
||||
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
depr.Log()
|
||||
|
||||
for _, removed := range []string{
|
||||
|
@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) {
|
|||
// after #2170 is cleaned up
|
||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
|
||||
PKCE: PKCEConfig{
|
||||
Enabled: viper.GetBool("oidc.pkce.enabled"),
|
||||
Method: viper.GetString("oidc.pkce.method"),
|
||||
},
|
||||
},
|
||||
|
||||
LogTail: logTailConfig,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue