create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
@ -82,7 +83,7 @@ func (h *Headscale) KeyHandler(
// Old clients don't send a 'v' parameter, so we send the legacy public key
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public())))
if err != nil {
log.Error().
Caller().
@ -102,7 +103,7 @@ func (h *Headscale) handleRegisterCommon(
isNoise bool,
) {
now := time.Now().UTC()
machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
machine, err := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
@ -120,7 +121,7 @@ func (h *Headscale) handleRegisterCommon(
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
@ -152,7 +153,7 @@ func (h *Headscale) handleRegisterCommon(
Bool("noise", isNoise).
Msg("New machine not yet in the database")
givenName, err := h.GenerateGivenName(
givenName, err := h.db.GenerateGivenName(
machineKey.String(),
registerRequest.Hostinfo.Hostname,
)
@ -171,10 +172,10 @@ func (h *Headscale) handleRegisterCommon(
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: MachinePublicKeyStripPrefix(machineKey),
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
@ -210,11 +211,11 @@ func (h *Headscale) handleRegisterCommon(
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey)
if err := h.db.Save(&machine).Error; err != nil {
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
if err := h.db.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
@ -231,7 +232,7 @@ func (h *Headscale) handleRegisterCommon(
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for /map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
@ -251,7 +252,7 @@ func (h *Headscale) handleRegisterCommon(
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() {
h.handleMachineRefreshKeyCommon(
writer,
@ -282,9 +283,9 @@ func (h *Headscale) handleRegisterCommon(
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.registrationCache.Set(
NodePublicKeyStripPrefix(registerRequest.NodeKey),
util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
*machine,
registerCacheExpiration,
)
@ -311,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
@ -372,13 +373,13 @@ func (h *Headscale) handleAuthKeyCommon(
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
machine, _ := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
@ -388,7 +389,7 @@ func (h *Headscale) handleAuthKeyCommon(
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)
err := h.db.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
@ -403,7 +404,7 @@ func (h *Headscale) handleAuthKeyCommon(
aclTags := pak.toProto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags)
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
if err != nil {
log.Error().
@ -420,7 +421,7 @@ func (h *Headscale) handleAuthKeyCommon(
} else {
now := time.Now().UTC()
givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
@ -436,7 +437,7 @@ func (h *Headscale) handleAuthKeyCommon(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: MachinePublicKeyStripPrefix(machineKey),
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
@ -445,7 +446,7 @@ func (h *Headscale) handleAuthKeyCommon(
ForcedTags: pak.toProto().AclTags,
}
machine, err = h.RegisterMachine(
machine, err = h.db.RegisterMachine(
machineToRegister,
)
if err != nil {
@ -462,7 +463,7 @@ func (h *Headscale) handleAuthKeyCommon(
}
}
err = h.UsePreAuthKey(pak)
err = h.db.UsePreAuthKey(pak)
if err != nil {
log.Error().
Caller().
@ -591,7 +592,7 @@ func (h *Headscale) handleMachineLogOutCommon(
Str("machine", machine.Hostname).
Msg("Client requested logout")
err := h.ExpireMachine(&machine)
err := h.db.ExpireMachine(&machine)
if err != nil {
log.Error().
Caller().
@ -634,7 +635,7 @@ func (h *Headscale) handleMachineLogOutCommon(
}
if machine.isEphemeral() {
err = h.HardDeleteMachine(&machine)
err = h.db.HardDeleteMachine(&machine)
if err != nil {
log.Error().
Err(err).
@ -720,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.Save(&machine).Error; err != nil {
if err := h.db.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Err(err).