feat: Add PKCE Verifier for OIDC (#2314)

* feat: add PKCE verifier for OIDC

* Update CHANGELOG.md
This commit is contained in:
Rorical 2024-12-23 00:46:36 +08:00 committed by GitHub
parent 9313e5b058
commit b81420bef1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 187 additions and 15 deletions

View file

@ -26,11 +26,14 @@ import (
const (
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
maxDuration time.Duration = 1<<63 - 1
PKCEMethodPlain string = "plain"
PKCEMethodS256 string = "S256"
)
var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
)
type IPAllocationStrategy string
@ -162,6 +165,11 @@ type LetsEncryptConfig struct {
ChallengeType string
}
type PKCEConfig struct {
Enabled bool
Method string
}
type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool
Issuer string
@ -176,6 +184,7 @@ type OIDCConfig struct {
Expiry time.Duration
UseExpiryFromToken bool
MapLegacyUsers bool
PKCE PKCEConfig
}
type DERPConfig struct {
@ -226,6 +235,13 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}
func validatePKCEMethod(method string) error {
if method != PKCEMethodPlain && method != PKCEMethodS256 {
return errInvalidPKCEMethod
}
return nil
}
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.map_legacy_users", true)
viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256")
viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false)
@ -340,6 +358,12 @@ func validateServerConfig() error {
// after #2170 is cleaned up
// depr.fatal("oidc.strip_email_domain")
if viper.GetBool("oidc.enabled") {
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
return err
}
}
depr.Log()
for _, removed := range []string{
@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) {
// after #2170 is cleaned up
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
PKCE: PKCEConfig{
Enabled: viper.GetBool("oidc.pkce.enabled"),
Method: viper.GetString("oidc.pkce.method"),
},
},
LogTail: logTailConfig,