updates from code review
This commit is contained in:
parent
35795c79c3
commit
e407d423d4
4 changed files with 131 additions and 43 deletions
88
oidc.go
88
oidc.go
|
@ -13,6 +13,7 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -23,9 +24,33 @@ type IDTokenClaims struct {
|
|||
Username string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
var oidcProvider *oidc.Provider
|
||||
var oauth2Config *oauth2.Config
|
||||
var stateCache *cache.Cache
|
||||
func (h *Headscale) initOIDC() error {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
if h.oauth2Config == nil {
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDCClientID,
|
||||
ClientSecret: h.cfg.OIDCClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
// init the state cache if it hasn't been already
|
||||
if h.oidcStateCache == nil {
|
||||
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
|
@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// grab oidc config if it hasn't been already
|
||||
if oauth2Config == nil {
|
||||
oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config")
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDCClientID,
|
||||
ClientSecret: h.cfg.OIDCClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
b := make([]byte, 16)
|
||||
_, err = rand.Read(b)
|
||||
_, err := rand.Read(b)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
|
@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
|
||||
stateStr := hex.EncodeToString(b)[:32]
|
||||
|
||||
// init the state cache if it hasn't been already
|
||||
if stateCache == nil {
|
||||
stateCache = cache.New(time.Minute*5, time.Minute*10)
|
||||
}
|
||||
|
||||
// place the machine key into the state cache, so it can be retrieved later
|
||||
stateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
||||
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
||||
|
||||
authUrl := oauth2Config.AuthCodeURL(stateStr)
|
||||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||
|
||||
c.Redirect(http.StatusFound, authUrl)
|
||||
|
@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := oauth2Config.Exchange(context.Background(), code)
|
||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
return
|
||||
|
@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
||||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
|
@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
}
|
||||
|
||||
//retrieve machinekey from state cache
|
||||
mKeyIf, mKeyFound := stateCache.Get(state)
|
||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
c.String(http.StatusBadRequest, "state has expired")
|
||||
|
@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
//look for a namespace of the users email for now
|
||||
if !m.Registered {
|
||||
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(claims.Email)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(claims.Email)
|
||||
|
@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
if m.isExpired() {
|
||||
maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry)
|
||||
|
||||
// use the maximum expiry if it's sooner than the requested expiry
|
||||
if maxExpiry.Before(*m.Expiry) {
|
||||
log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
|
||||
m.Expiry = &maxExpiry
|
||||
h.db.Save(&m)
|
||||
} else if m.Expiry.IsZero() {
|
||||
log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
|
||||
defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry)
|
||||
m.Expiry = &defaultExpiry
|
||||
h.db.Save(&m)
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue