use dedicated registration ID for auth flow (#2337)

This commit is contained in:
Kristoffer Dalby 2025-01-26 22:20:11 +01:00 committed by GitHub
parent 97e5d95399
commit 4c8e847f47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 586 additions and 586 deletions

View file

@ -21,7 +21,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
"zgo.at/zcache/v2"
)
@ -49,8 +48,8 @@ var (
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier *string
RegistrationID types.RegistrationID
Verifier *string
}
type AuthProviderOIDC struct {
@ -112,11 +111,11 @@ func NewAuthProviderOIDC(
}, nil
}
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
registrationID.String())
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /register/:mKey.
// Listens in /register/:registration_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
log.Debug().
Caller().
Str("machine_key", machineKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")
registrationIdStr, ok := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
return
}
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Bool("ok", ok).
Msg("Received oidc register call")
// Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state")
if err != nil {
@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
MachineKey: machineKey,
RegistrationID: registrationId,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@ -290,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// Retrieve the node and the machine key from the state cache and
// database.
// TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
node, mKey := a.getMachineKeyFromState(state)
registrationId := a.getRegistrationIDFromState(state)
// Reauthenticate the node if it does exists.
if node != nil {
err := a.reauthenticateNode(node, nodeExpiry)
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
if newNode {
verb = "Authenticated"
}
// TODO(kradalby): replace with go-elem
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Reauthenticated",
}); err != nil {
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
// Register the node if it does not exist.
if mKey != nil {
if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
content, err := renderOIDCCallbackTemplate(user)
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
@ -456,49 +430,14 @@ func validateOIDCAllowedUsers(
return nil
}
// getMachineKeyFromState retrieves the machine key from the state
// cache. If the machine key is found, it will try retrieve the
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, nil
return nil
}
// retrieve node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
return node, &regInfo.MachineKey
}
// reauthenticateNode updates the node expiry in the database
// and notifies the node and its peers about the change.
func (a *AuthProviderOIDC) reauthenticateNode(
node *types.Node,
expiry time.Time,
) error {
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
return err
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil
return &regInfo.RegistrationID
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@ -556,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil
}
func (a *AuthProviderOIDC) registerNode(
func (a *AuthProviderOIDC) handleRegistrationID(
user *types.User,
machineKey *key.MachinePublic,
registrationID types.RegistrationID,
expiry time.Time,
) error {
) (bool, error) {
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
return false, err
}
if _, err := a.db.RegisterNodeFromAuthCallback(
*machineKey,
node, newNode, err := a.db.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
); err != nil {
return fmt.Errorf("could not register node: %w", err)
}
err = nodesChangedHook(a.db, a.polMan, a.notifier)
)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
return false, fmt.Errorf("could not register node: %w", err)
}
return nil
// Send an update to all nodes if this is a new node that they need to know
// about.
// If this is a refresh, just send new expiry updates.
if newNode {
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return false, fmt.Errorf("updating resources using node: %w", err)
}
} else {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
}
return newNode, nil
}
// TODO(kradalby):
// Rewrite in elem-go.
func renderOIDCCallbackTemplate(
user *types.User,
verb string,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Authenticated",
Verb: verb,
}); err != nil {
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}