updates from code review

This commit is contained in:
Raal Goff 2021-10-08 17:43:52 +08:00
parent 35795c79c3
commit e407d423d4
4 changed files with 131 additions and 43 deletions

88
oidc.go
View file

@ -13,6 +13,7 @@ import (
"golang.org/x/oauth2"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
@ -23,9 +24,33 @@ type IDTokenClaims struct {
Username string `json:"preferred_username,omitempty"`
}
var oidcProvider *oidc.Provider
var oauth2Config *oauth2.Config
var stateCache *cache.Cache
func (h *Headscale) initOIDC() error {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID,
ClientSecret: h.cfg.OIDCClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
}
return nil
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
return
}
var err error
// grab oidc config if it hasn't been already
if oauth2Config == nil {
oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config")
return
}
oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID,
ClientSecret: h.cfg.OIDCClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
b := make([]byte, 16)
_, err = rand.Read(b)
_, err := rand.Read(b)
if err != nil {
log.Error().Msg("could not read 16 bytes from rand")
@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
stateStr := hex.EncodeToString(b)[:32]
// init the state cache if it hasn't been already
if stateCache == nil {
stateCache = cache.New(time.Minute*5, time.Minute*10)
}
// place the machine key into the state cache, so it can be retrieved later
stateCache.Set(stateStr, mKeyStr, time.Minute*5)
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
authUrl := oauth2Config.AuthCodeURL(stateStr)
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
c.Redirect(http.StatusFound, authUrl)
@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}
oauth2Token, err := oauth2Config.Exchange(context.Background(), code)
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token")
return
@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
}
//retrieve machinekey from state cache
mKeyIf, mKeyFound := stateCache.Get(state)
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
c.String(http.StatusBadRequest, "state has expired")
@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
//look for a namespace of the users email for now
if !m.Registered {
log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(claims.Email)
if err != nil {
ns, err = h.CreateNamespace(claims.Email)
@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
h.db.Save(&m)
}
if m.isExpired() {
maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry)
// use the maximum expiry if it's sooner than the requested expiry
if maxExpiry.Before(*m.Expiry) {
log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
m.Expiry = &maxExpiry
h.db.Save(&m)
} else if m.Expiry.IsZero() {
log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry)
m.Expiry = &defaultExpiry
h.db.Save(&m)
}
}
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>