use dedicated registration ID for auth flow (#2337)

This commit is contained in:
Kristoffer Dalby 2025-01-26 22:20:11 +01:00 committed by GitHub
parent 97e5d95399
commit 4c8e847f47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 586 additions and 586 deletions

View file

@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
cfg *types.DatabaseConfig
regCache *zcache.Cache[string, types.Node]
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
baseDomain string
}
@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
baseDomain string,
regCache *zcache.Cache[string, types.Node],
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
if err != nil {

View file

@ -260,8 +260,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err
}
func emptyCache() *zcache.Cache[string, types.Node] {
return zcache.New[string, types.Node](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
}
// requireConstraintFailed checks if the error is a constraint failure with

View file

@ -158,6 +158,30 @@ func GetNodeByMachineKey(
return &mach, nil
}
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByNodeKey(rx, nodeKey)
})
}
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
func GetNodeByNodeKey(
tx *gorm.DB,
nodeKey key.NodePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
return nil, result.Error
}
return &mach, nil
}
func (hsdb *HSDatabase) GetNodeByAnyKey(
machineKey key.MachinePublic,
nodeKey key.NodePublic,
@ -319,60 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
mkey key.MachinePublic,
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
// it will be registered with the user and the node will be removed from the cache.
// If the node is already registered, the expiry will be updated.
// The node, and a boolean indicating if it was a new node or not, will be returned.
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
err,
) (*types.Node, bool, error) {
var newNode bool
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
err,
)
}
log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
)
if err == nil {
hsdb.regCache.Delete(registrationID)
}
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
newNode = true
return node, err
} else {
// If the node is already registered, this is a refresh.
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
if err != nil {
return nil, err
}
return node, nil
}
log.Debug().
Str("machine_key", mkey.ShortString()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// Registration of expired node with different user
if node.ID != 0 &&
node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
node.UserID = user.ID
node.User = *user
node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
node,
ipv4, ipv6,
)
if err == nil {
hsdb.regCache.Delete(mkey.String())
}
return node, err
}
return nil, ErrNodeNotFoundRegistrationCache
})
return node, newNode, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {