create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -82,7 +83,7 @@ func (h *Headscale) KeyHandler(
|
|||
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
||||
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -102,7 +103,7 @@ func (h *Headscale) handleRegisterCommon(
|
|||
isNoise bool,
|
||||
) {
|
||||
now := time.Now().UTC()
|
||||
machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||
machine, err := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
||||
if registerRequest.Auth.AuthKey != "" {
|
||||
|
@ -120,7 +121,7 @@ func (h *Headscale) handleRegisterCommon(
|
|||
// is that the client will hammer headscale with requests until it gets a
|
||||
// successful RegisterResponse.
|
||||
if registerRequest.Followup != "" {
|
||||
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
|
||||
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
|
@ -152,7 +153,7 @@ func (h *Headscale) handleRegisterCommon(
|
|||
Bool("noise", isNoise).
|
||||
Msg("New machine not yet in the database")
|
||||
|
||||
givenName, err := h.GenerateGivenName(
|
||||
givenName, err := h.db.GenerateGivenName(
|
||||
machineKey.String(),
|
||||
registerRequest.Hostinfo.Hostname,
|
||||
)
|
||||
|
@ -171,10 +172,10 @@ func (h *Headscale) handleRegisterCommon(
|
|||
// We create the machine and then keep it around until a callback
|
||||
// happens
|
||||
newMachine := Machine{
|
||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||
LastSeen: &now,
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
|
@ -210,11 +211,11 @@ func (h *Headscale) handleRegisterCommon(
|
|||
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
|
||||
var storedMachineKey key.MachinePublic
|
||||
err = storedMachineKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil || storedMachineKey.IsZero() {
|
||||
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey)
|
||||
if err := h.db.Save(&machine).Error; err != nil {
|
||||
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
|
||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RegistrationHandler").
|
||||
|
@ -231,7 +232,7 @@ func (h *Headscale) handleRegisterCommon(
|
|||
// - Trying to log out (sending a expiry in the past)
|
||||
// - A valid, registered machine, looking for /map
|
||||
// - Expired machine wanting to reauthenticate
|
||||
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
|
||||
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) {
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !registerRequest.Expiry.IsZero() &&
|
||||
|
@ -251,7 +252,7 @@ func (h *Headscale) handleRegisterCommon(
|
|||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
||||
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
||||
!machine.isExpired() {
|
||||
h.handleMachineRefreshKeyCommon(
|
||||
writer,
|
||||
|
@ -282,9 +283,9 @@ func (h *Headscale) handleRegisterCommon(
|
|||
// we need to make sure the NodeKey matches the one in the request
|
||||
// TODO(juan): What happens when using fast user switching between two
|
||||
// headscale-managed tailnets?
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
h.registrationCache.Set(
|
||||
NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||
util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||
*machine,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
@ -311,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -372,13 +373,13 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||
|
||||
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
|
||||
// retrieve machine information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new machine and we will move
|
||||
// on to registration.
|
||||
machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||
machine, _ := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||
if machine != nil {
|
||||
log.Trace().
|
||||
Caller().
|
||||
|
@ -388,7 +389,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
|
||||
machine.NodeKey = nodeKey
|
||||
machine.AuthKeyID = uint(pak.ID)
|
||||
err := h.RefreshMachine(machine, registerRequest.Expiry)
|
||||
err := h.db.RefreshMachine(machine, registerRequest.Expiry)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -403,7 +404,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
aclTags := pak.toProto().AclTags
|
||||
if len(aclTags) > 0 {
|
||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||
err = h.SetTags(machine, aclTags)
|
||||
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -420,7 +421,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
} else {
|
||||
now := time.Now().UTC()
|
||||
|
||||
givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
|
||||
givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -436,7 +437,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
UserID: pak.User.ID,
|
||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
Expiry: ®isterRequest.Expiry,
|
||||
NodeKey: nodeKey,
|
||||
|
@ -445,7 +446,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
ForcedTags: pak.toProto().AclTags,
|
||||
}
|
||||
|
||||
machine, err = h.RegisterMachine(
|
||||
machine, err = h.db.RegisterMachine(
|
||||
machineToRegister,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -462,7 +463,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
}
|
||||
}
|
||||
|
||||
err = h.UsePreAuthKey(pak)
|
||||
err = h.db.UsePreAuthKey(pak)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -591,7 +592,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
|||
Str("machine", machine.Hostname).
|
||||
Msg("Client requested logout")
|
||||
|
||||
err := h.ExpireMachine(&machine)
|
||||
err := h.db.ExpireMachine(&machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -634,7 +635,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
|||
}
|
||||
|
||||
if machine.isEphemeral() {
|
||||
err = h.HardDeleteMachine(&machine)
|
||||
err = h.db.HardDeleteMachine(&machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -720,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
|||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
|
||||
if err := h.db.Save(&machine).Error; err != nil {
|
||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue