Split code into modules

This is a massive commit that restructures the code into modules:

db/
    All functions related to modifying the Database

types/
    All type definitions and methods that can be exclusivly used on
    these types without dependencies

policy/
    All Policy related code, now without dependencies on the Database.

policy/matcher/
    Dedicated code to match machines in a list of FilterRules

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-21 19:37:59 +03:00 committed by Kristoffer Dalby
parent 14e29a7bee
commit feb15365b5
51 changed files with 4677 additions and 4290 deletions

View file

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon(
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
newMachine := types.Machine{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon(
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
if err := h.db.db.Save(&machine).Error; err != nil {
if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon(
// If machine is not expired, and it is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
if !machine.IsExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
return
@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon(
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() {
!machine.IsExpired() {
h.handleMachineRefreshKeyCommon(
writer,
registerRequest,
@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon(
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
return
@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon(
Msg("Failed authentication via AuthKey")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon(
return
}
aclTags := pak.toProto().AclTags
aclTags := pak.Proto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
err = h.db.SetTags(machine, aclTags)
if err != nil {
log.Error().
@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon(
return
}
machineToRegister := Machine{
machineToRegister := types.Machine{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
RegisterMethod: RegisterMethodAuthKey,
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags,
ForcedTags: pak.Proto().AclTags,
}
machine, err = h.db.RegisterMachine(
@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon(
Bool("noise", isNoise).
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon(
Bool("noise", isNoise).
Err(err).
Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon(
}
resp.MachineAuthorized = true
resp.User = *pak.User.toTailscaleUser()
resp.User = *pak.User.TailscaleUser()
// Provide LoginName when registering with pre-auth key
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
resp.Login = *pak.User.toTailscaleLogin()
resp.Login = *pak.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon(
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name).
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon(
func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter,
machine Machine,
machine types.Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon(
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.NodeKeyExpired = true
resp.User = *machine.User.toTailscaleUser()
resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon(
return
}
if machine.isEphemeral() {
if machine.IsEphemeral() {
err = h.db.HardDeleteMachine(&machine)
if err != nil {
log.Error().
@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon(
func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter,
machine Machine,
machine types.Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.User.toTailscaleUser()
resp.Login = *machine.User.toTailscaleLogin()
resp.User = *machine.User.TailscaleUser()
resp.Login = *machine.User.TailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machine types.Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.db.Save(&machine).Error; err != nil {
err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey)
if err != nil {
log.Error().
Caller().
Err(err).
@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
}
resp.AuthURL = ""
resp.User = *machine.User.toTailscaleUser()
resp.User = *machine.User.TailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machine types.Machine,
machineKey key.MachinePublic,
isNoise bool,
) {