use dedicated registration ID for auth flow (#2337)
This commit is contained in:
parent
97e5d95399
commit
4c8e847f47
26 changed files with 586 additions and 586 deletions
|
@ -41,7 +41,7 @@ type KV struct {
|
|||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
cfg *types.DatabaseConfig
|
||||
regCache *zcache.Cache[string, types.Node]
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
|
||||
baseDomain string
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ type HSDatabase struct {
|
|||
func NewHeadscaleDatabase(
|
||||
cfg types.DatabaseConfig,
|
||||
baseDomain string,
|
||||
regCache *zcache.Cache[string, types.Node],
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(cfg)
|
||||
if err != nil {
|
||||
|
|
|
@ -260,8 +260,8 @@ func testCopyOfDatabase(src string) (string, error) {
|
|||
return dst, err
|
||||
}
|
||||
|
||||
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
// requireConstraintFailed checks if the error is a constraint failure with
|
||||
|
|
|
@ -158,6 +158,30 @@ func GetNodeByMachineKey(
|
|||
return &mach, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByNodeKey(rx, nodeKey)
|
||||
})
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
|
||||
func GetNodeByNodeKey(
|
||||
tx *gorm.DB,
|
||||
nodeKey key.NodePublic,
|
||||
) (*types.Node, error) {
|
||||
mach := types.Node{}
|
||||
if result := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &mach, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByAnyKey(
|
||||
machineKey key.MachinePublic,
|
||||
nodeKey key.NodePublic,
|
||||
|
@ -319,60 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
|||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
||||
mkey key.MachinePublic,
|
||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
||||
// with a registrationID to register or reauthenticate a node.
|
||||
// If the node found in the registration cache is not already registered,
|
||||
// it will be registered with the user and the node will be removed from the cache.
|
||||
// If the node is already registered, the expiry will be updated.
|
||||
// The node, and a boolean indicating if it was a new node or not, will be returned.
|
||||
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationID types.RegistrationID,
|
||||
userID types.UserID,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
|
||||
user, err := GetUserByID(tx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
err,
|
||||
) (*types.Node, bool, error) {
|
||||
var newNode bool
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
user, err := GetUserByID(tx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registration_id", registrationID.String()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||
Msg("Registering node from API/CLI or auth callback")
|
||||
|
||||
// TODO(kradalby): This looks quite wrong? why ID 0?
|
||||
// Why not always?
|
||||
// Registration of expired node with different user
|
||||
if reg.Node.ID != 0 &&
|
||||
reg.Node.UserID != user.ID {
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
reg.Node.UserID = user.ID
|
||||
reg.Node.User = *user
|
||||
reg.Node.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
reg.Node.Expiry = nodeExpiry
|
||||
}
|
||||
|
||||
node, err := RegisterNode(
|
||||
tx,
|
||||
reg.Node,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
hsdb.regCache.Delete(registrationID)
|
||||
}
|
||||
|
||||
// Signal to waiting clients that the machine has been registered.
|
||||
close(reg.Registered)
|
||||
newNode = true
|
||||
return node, err
|
||||
} else {
|
||||
// If the node is already registered, this is a refresh.
|
||||
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||
Msg("Registering node from API/CLI or auth callback")
|
||||
|
||||
// Registration of expired node with different user
|
||||
if node.ID != 0 &&
|
||||
node.UserID != user.ID {
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
node.UserID = user.ID
|
||||
node.User = *user
|
||||
node.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
node.Expiry = nodeExpiry
|
||||
}
|
||||
|
||||
node, err := RegisterNode(
|
||||
tx,
|
||||
node,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
hsdb.regCache.Delete(mkey.String())
|
||||
}
|
||||
|
||||
return node, err
|
||||
}
|
||||
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, newNode, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue