Implement namespace matching
This commit is contained in:
parent
a347d276bd
commit
677bd9b657
5 changed files with 267 additions and 55 deletions
93
oidc.go
93
oidc.go
|
@ -5,14 +5,16 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
|
@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error {
|
|||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
if h.oauth2Config == nil {
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
|
@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error {
|
|||
}
|
||||
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDCClientID,
|
||||
ClientSecret: h.cfg.OIDCClientSecret,
|
||||
ClientID: h.cfg.OIDC.ClientID,
|
||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
|
@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
|
||||
b := make([]byte, 16)
|
||||
_, err := rand.Read(b)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
|
@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||
// Listens in /oidc/callback
|
||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
|
@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
|
@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
//retrieve machinekey from state cache
|
||||
// retrieve machinekey from state cache
|
||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
|
@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
// retrieve machine information
|
||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
||||
|
@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation
|
||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
return
|
||||
}
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
return
|
||||
}
|
||||
h.updateMachineExpiry(m)
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(m)
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
|
@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
</html>
|
||||
|
||||
`, claims.Email)))
|
||||
|
||||
}
|
||||
|
||||
log.Error().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", m.Name).
|
||||
Msg("Email could not be mapped to a namespace")
|
||||
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
||||
}
|
||||
|
||||
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
|
||||
for match, namespace := range h.cfg.OIDC.MatchMap {
|
||||
regex := regexp.MustCompile(match)
|
||||
if regex.MatchString(email) {
|
||||
return namespace, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue