use dedicated registration ID for auth flow (#2337)
This commit is contained in:
parent
97e5d95399
commit
4c8e847f47
26 changed files with 586 additions and 586 deletions
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/types/key"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
|
@ -49,8 +48,8 @@ var (
|
|||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
MachineKey key.MachinePublic
|
||||
Verifier *string
|
||||
RegistrationID types.RegistrationID
|
||||
Verifier *string
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
|
@ -112,11 +111,11 @@ func NewAuthProviderOIDC(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
|
||||
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
mKey.String())
|
||||
registrationID.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
||||
|
@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time
|
|||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /register/:mKey.
|
||||
// Listens in /register/:registration_id.
|
||||
func (a *AuthProviderOIDC) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
machineKeyStr, ok := vars["mkey"]
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("machine_key", machineKeyStr).
|
||||
Bool("ok", ok).
|
||||
Msg("Received oidc register call")
|
||||
registrationIdStr, ok := vars["registration_id"]
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText(
|
||||
[]byte(machineKeyStr),
|
||||
)
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("registration_id", registrationId.String()).
|
||||
Bool("ok", ok).
|
||||
Msg("Received oidc register call")
|
||||
|
||||
// Set the state and nonce cookies to protect against CSRF attacks
|
||||
state, err := setCSRFCookie(writer, req, "state")
|
||||
if err != nil {
|
||||
|
@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
MachineKey: machineKey,
|
||||
RegistrationID: registrationId,
|
||||
}
|
||||
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||
|
@ -290,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
return
|
||||
}
|
||||
|
||||
// Retrieve the node and the machine key from the state cache and
|
||||
// database.
|
||||
// TODO(kradalby): Is this comment right?
|
||||
// If the node exists, then the node should be reauthenticated,
|
||||
// if the node does not exist, and the machine key exists, then
|
||||
// this is a new node that should be registered.
|
||||
node, mKey := a.getMachineKeyFromState(state)
|
||||
registrationId := a.getRegistrationIDFromState(state)
|
||||
|
||||
// Reauthenticate the node if it does exists.
|
||||
if node != nil {
|
||||
err := a.reauthenticateNode(node, nodeExpiry)
|
||||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if newNode {
|
||||
verb = "Authenticated"
|
||||
}
|
||||
|
||||
// TODO(kradalby): replace with go-elem
|
||||
var content bytes.Buffer
|
||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||
User: user.DisplayNameOrUsername(),
|
||||
Verb: "Reauthenticated",
|
||||
}); err != nil {
|
||||
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err = writer.Write(content.Bytes())
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Register the node if it does not exist.
|
||||
if mKey != nil {
|
||||
if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := renderOIDCCallbackTemplate(user)
|
||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -456,49 +430,14 @@ func validateOIDCAllowedUsers(
|
|||
return nil
|
||||
}
|
||||
|
||||
// getMachineKeyFromState retrieves the machine key from the state
|
||||
// cache. If the machine key is found, it will try retrieve the
|
||||
// node information from the database.
|
||||
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
|
||||
// getRegistrationIDFromState retrieves the registration ID from the state.
|
||||
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// retrieve node information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
|
||||
|
||||
return node, ®Info.MachineKey
|
||||
}
|
||||
|
||||
// reauthenticateNode updates the node expiry in the database
|
||||
// and notifies the node and its peers about the change.
|
||||
func (a *AuthProviderOIDC) reauthenticateNode(
|
||||
node *types.Node,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
err := a.db.NodeSetExpiry(node.ID, expiry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: []types.NodeID{node.ID},
|
||||
},
|
||||
node.ID,
|
||||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||
|
||||
return nil
|
||||
return ®Info.RegistrationID
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
|
@ -556,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) registerNode(
|
||||
func (a *AuthProviderOIDC) handleRegistrationID(
|
||||
user *types.User,
|
||||
machineKey *key.MachinePublic,
|
||||
registrationID types.RegistrationID,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
) (bool, error) {
|
||||
ipv4, ipv6, err := a.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if _, err := a.db.RegisterNodeFromAuthCallback(
|
||||
*machineKey,
|
||||
node, newNode, err := a.db.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
types.UserID(user.ID),
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
ipv4, ipv6,
|
||||
); err != nil {
|
||||
return fmt.Errorf("could not register node: %w", err)
|
||||
}
|
||||
|
||||
err = nodesChangedHook(a.db, a.polMan, a.notifier)
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating resources using node: %w", err)
|
||||
return false, fmt.Errorf("could not register node: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Send an update to all nodes if this is a new node that they need to know
|
||||
// about.
|
||||
// If this is a refresh, just send new expiry updates.
|
||||
if newNode {
|
||||
err = nodesChangedHook(a.db, a.polMan, a.notifier)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("updating resources using node: %w", err)
|
||||
}
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: []types.NodeID{node.ID},
|
||||
},
|
||||
node.ID,
|
||||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||
}
|
||||
|
||||
return newNode, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby):
|
||||
// Rewrite in elem-go.
|
||||
func renderOIDCCallbackTemplate(
|
||||
user *types.User,
|
||||
verb string,
|
||||
) (*bytes.Buffer, error) {
|
||||
var content bytes.Buffer
|
||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||
User: user.DisplayNameOrUsername(),
|
||||
Verb: "Authenticated",
|
||||
Verb: verb,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue