create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -11,6 +11,8 @@ import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"go4.org/netipx"
@ -21,23 +23,23 @@ import (
)
const (
ErrMachineNotFound = Error("machine not found")
ErrMachineRouteIsNotAvailable = Error("route is not available on machine")
ErrMachineAddressesInvalid = Error("failed to parse machine addresses")
ErrMachineNotFoundRegistrationCache = Error(
"machine not found in registration cache",
)
ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface")
ErrHostnameTooLong = Error("Hostname too long")
ErrDifferentRegisteredUser = Error(
"machine was previously registered with a different user",
)
MachineGivenNameHashLength = 8
MachineGivenNameTrimSize = 2
maxHostnameLength = 255
)
const (
maxHostnameLength = 255
var (
ErrMachineNotFound = errors.New("machine not found")
ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine")
ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
ErrMachineNotFoundRegistrationCache = errors.New(
"machine not found in registration cache",
)
ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface")
ErrHostnameTooLong = errors.New("hostname too long")
ErrDifferentRegisteredUser = errors.New(
"machine was previously registered with a different user",
)
)
// Machine is a Headscale client.
@ -188,8 +190,10 @@ func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
// related to the application outside of tests.
func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines {
return filterMachinesByACL(currentMachine, peers, h.aclRules)
func (hsdb *HSDatabase) filterMachinesByACL(
aclRules []tailcfg.FilterRule,
currentMachine *Machine, peers Machines) Machines {
return filterMachinesByACL(currentMachine, peers, aclRules)
}
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
@ -213,14 +217,14 @@ func filterMachinesByACL(
return result
}
func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", machine.Hostname).
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
machine.NodeKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
@ -237,23 +241,27 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
return machines, nil
}
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
func (hsdb *HSDatabase) getPeers(
aclPolicy *ACLPolicy,
aclRules []tailcfg.FilterRule,
machine *Machine,
) (Machines, error) {
var peers Machines
var err error
// If ACLs rules are defined, filter visible host list with the ACLs
// else use the classic user scope
if h.aclPolicy != nil {
if aclPolicy != nil {
var machines []Machine
machines, err = h.ListMachines()
machines, err = hsdb.ListMachines()
if err != nil {
log.Error().Err(err).Msg("Error retrieving list of machines")
return Machines{}, err
}
peers = h.filterMachinesByACL(machine, machines)
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
} else {
peers, err = h.ListPeers(machine)
peers, err = hsdb.ListPeers(machine)
if err != nil {
log.Error().
Caller().
@ -275,10 +283,14 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
return peers, nil
}
func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
func (hsdb *HSDatabase) getValidPeers(
aclPolicy *ACLPolicy,
aclRules []tailcfg.FilterRule,
machine *Machine,
) (Machines, error) {
validPeers := make(Machines, 0)
peers, err := h.getPeers(machine)
peers, err := hsdb.getPeers(aclPolicy, aclRules, machine)
if err != nil {
return Machines{}, err
}
@ -292,18 +304,18 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
return validPeers, nil
}
func (h *Headscale) ListMachines() ([]Machine, error) {
func (hsdb *HSDatabase) ListMachines() ([]Machine, error) {
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) {
func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) {
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
return nil, err
}
@ -311,8 +323,8 @@ func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error)
}
// GetMachine finds a Machine by name and user and returns the Machine struct.
func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
machines, err := h.ListMachinesByUser(user)
func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) {
machines, err := hsdb.ListMachinesByUser(user)
if err != nil {
return nil, err
}
@ -327,8 +339,8 @@ func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
}
// GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct.
func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
machines, err := h.ListMachinesByUser(user)
func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
machines, err := hsdb.ListMachinesByUser(user)
if err != nil {
return nil, err
}
@ -343,9 +355,9 @@ func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machi
}
// GetMachineByID finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error
}
@ -353,11 +365,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
}
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(
func (hsdb *HSDatabase) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
return nil, result.Error
}
@ -365,12 +377,12 @@ func (h *Headscale) GetMachineByMachineKey(
}
// GetMachineByNodeKey finds a Machine by its current NodeKey.
func (h *Headscale) GetMachineByNodeKey(
func (hsdb *HSDatabase) GetMachineByNodeKey(
nodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
return nil, result.Error
}
@ -378,14 +390,14 @@ func (h *Headscale) GetMachineByNodeKey(
}
// GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByAnyKey(
func (hsdb *HSDatabase) GetMachineByAnyKey(
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
MachinePublicKeyStripPrefix(machineKey),
NodePublicKeyStripPrefix(nodeKey),
NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
util.MachinePublicKeyStripPrefix(machineKey),
util.NodePublicKeyStripPrefix(nodeKey),
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
@ -394,8 +406,8 @@ func (h *Headscale) GetMachineByAnyKey(
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
if result := h.db.Find(machine).First(&machine); result.Error != nil {
func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error {
if result := hsdb.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
@ -403,20 +415,28 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
}
// SetTags takes a Machine struct pointer and update the forced tags.
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
func (hsdb *HSDatabase) SetTags(
machine *Machine,
tags []string,
// TODO(kradalby): This is a temporary measure to be able to detach the
// database completely from the global h. In the future, as part of this
// reorg, the rules will be generated on a per node basis, and not be prone
// to throwing error at save.
updateACL func() error) error {
newTags := []string{}
for _, tag := range tags {
if !contains(newTags, tag) {
if !util.StringOrPrefixListContains(newTags, tag) {
newTags = append(newTags, tag)
}
}
machine.ForcedTags = newTags
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) {
return err
}
h.setLastStateChangeToNow()
if err := h.db.Save(machine).Error; err != nil {
hsdb.notifyStateChange()
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
}
@ -424,13 +444,13 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error {
}
// ExpireMachine takes a Machine struct and sets the expire field to now.
func (h *Headscale) ExpireMachine(machine *Machine) error {
func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error {
now := time.Now()
machine.Expiry = &now
h.setLastStateChangeToNow()
hsdb.notifyStateChange()
if err := h.db.Save(machine).Error; err != nil {
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err)
}
@ -439,7 +459,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
// RenameMachine takes a Machine struct and a new GivenName for the machines
// and renames it.
func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error {
err := CheckForFQDNRules(
newName,
)
@ -455,9 +475,9 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
}
machine.GivenName = newName
h.setLastStateChangeToNow()
hsdb.notifyStateChange()
if err := h.db.Save(machine).Error; err != nil {
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err)
}
@ -465,15 +485,15 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
}
// RefreshMachine takes a Machine struct and sets the expire field to now.
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error {
now := time.Now()
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
h.setLastStateChangeToNow()
hsdb.notifyStateChange()
if err := h.db.Save(machine).Error; err != nil {
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf(
"failed to refresh machine (update expiration) in the database: %w",
err,
@ -484,21 +504,21 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
}
// DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.DeleteMachineRoutes(machine)
func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error {
err := hsdb.DeleteMachineRoutes(machine)
if err != nil {
return err
}
if err := h.db.Delete(&machine).Error; err != nil {
if err := hsdb.db.Delete(&machine).Error; err != nil {
return err
}
return nil
}
func (h *Headscale) TouchMachine(machine *Machine) error {
return h.db.Updates(Machine{
func (hsdb *HSDatabase) TouchMachine(machine *Machine) error {
return hsdb.db.Updates(Machine{
ID: machine.ID,
LastSeen: machine.LastSeen,
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
@ -506,13 +526,13 @@ func (h *Headscale) TouchMachine(machine *Machine) error {
}
// HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
err := h.DeleteMachineRoutes(machine)
func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error {
err := hsdb.DeleteMachineRoutes(machine)
if err != nil {
return err
}
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil {
return err
}
@ -524,8 +544,8 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
return tailcfg.Hostinfo(machine.HostInfo)
}
func (h *Headscale) isOutdated(machine *Machine) bool {
if err := h.UpdateMachineFromDatabase(machine); err != nil {
func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool {
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
@ -536,7 +556,6 @@ func (h *Headscale) isOutdated(machine *Machine) bool {
// TODO(kradalby): Only request updates from users where we can talk to nodes
// This would mostly be for a bit of performance, and can be calculated based on
// ACLs.
lastChange := h.getLastStateChange()
lastUpdate := machine.CreatedAt
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
@ -576,15 +595,16 @@ func (machines MachinesP) String() string {
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
func (h *Headscale) toNodes(
func (hsdb *HSDatabase) toNodes(
machines Machines,
aclPolicy *ACLPolicy,
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(machines))
for index, machine := range machines {
node, err := h.toNode(machine, baseDomain, dnsConfig)
node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig)
if err != nil {
return nil, err
}
@ -597,13 +617,14 @@ func (h *Headscale) toNodes(
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS.
func (h *Headscale) toNode(
func (hsdb *HSDatabase) toNode(
machine Machine,
aclPolicy *ACLPolicy,
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
) (*tailcfg.Node, error) {
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey)))
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)))
if err != nil {
log.Trace().
Caller().
@ -617,7 +638,7 @@ func (h *Headscale) toNode(
// MachineKey is only used in the legacy protocol
if machine.MachineKey != "" {
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
@ -627,7 +648,7 @@ func (h *Headscale) toNode(
var discoKey key.DiscoPublic
if machine.DiscoKey != "" {
err := discoKey.UnmarshalText(
[]byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
[]byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
@ -646,13 +667,13 @@ func (h *Headscale) toNode(
[]netip.Prefix{},
addrs...) // we append the node own IP, as it is required by the clients
primaryRoutes, err := h.getMachinePrimaryRoutes(&machine)
primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine)
if err != nil {
return nil, err
}
primaryPrefixes := Routes(primaryRoutes).toPrefixes()
machineRoutes, err := h.GetMachineRoutes(&machine)
machineRoutes, err := hsdb.GetMachineRoutes(&machine)
if err != nil {
return nil, err
}
@ -699,13 +720,13 @@ func (h *Headscale) toNode(
online := machine.isOnline()
tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain)
tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain)
tags = lo.Uniq(append(tags, machine.ForcedTags...))
node := tailcfg.Node{
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
strconv.FormatUint(machine.ID, Base10),
strconv.FormatUint(machine.ID, util.Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
@ -827,7 +848,8 @@ func getTags(
return validTags, invalidTags
}
func (h *Headscale) RegisterMachineFromAuthCallback(
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
cache *cache.Cache,
nodeKeyStr string,
userName string,
machineExpiry *time.Time,
@ -846,9 +868,9 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
Str("expiresAt", fmt.Sprintf("%v", machineExpiry)).
Msg("Registering machine from API/CLI or auth callback")
if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok {
if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
user, err := h.GetUser(userName)
user, err := hsdb.GetUser(userName)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register machine from auth callback, %w",
@ -869,12 +891,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
registrationMachine.Expiry = machineExpiry
}
machine, err := h.RegisterMachine(
machine, err := hsdb.RegisterMachine(
registrationMachine,
)
if err == nil {
h.registrationCache.Delete(nodeKeyStr)
cache.Delete(nodeKeyStr)
}
return machine, err
@ -887,7 +909,7 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(machine Machine,
func (hsdb *HSDatabase) RegisterMachine(machine Machine,
) (*Machine, error) {
log.Debug().
Str("machine", machine.Hostname).
@ -900,7 +922,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
// so we store the machine.Expire and machine.Nodekey that has been set when
// adding it to the registrationCache
if len(machine.IPAddresses) > 0 {
if err := h.db.Save(&machine).Error; err != nil {
if err := hsdb.db.Save(&machine).Error; err != nil {
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
}
@ -915,10 +937,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
return &machine, nil
}
h.ipAllocationMutex.Lock()
defer h.ipAllocationMutex.Unlock()
hsdb.ipAllocationMutex.Lock()
defer hsdb.ipAllocationMutex.Unlock()
ips, err := h.getAvailableIPs()
ips, err := hsdb.getAvailableIPs()
if err != nil {
log.Error().
Caller().
@ -931,7 +953,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
machine.IPAddresses = ips
if err := h.db.Save(&machine).Error; err != nil {
if err := hsdb.db.Save(&machine).Error; err != nil {
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
}
@ -945,10 +967,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
}
// GetAdvertisedRoutes returns the routes that are be advertised by the given machine.
func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
routes := []Route{}
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -970,10 +992,10 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error
}
// GetEnabledRoutes returns the routes that are enabled for the machine.
func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
routes := []Route{}
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true).
Find(&routes).Error
@ -995,13 +1017,13 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
return prefixes, nil
}
func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := h.GetEnabledRoutes(machine)
enabledRoutes, err := hsdb.GetEnabledRoutes(machine)
if err != nil {
log.Error().Err(err).Msg("Could not get enabled routes")
@ -1018,7 +1040,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
}
// enableRoutes enables new routes based on a list of new routes.
func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netip.ParsePrefix(routeStr)
@ -1029,13 +1051,13 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
advertisedRoutes, err := h.GetAdvertisedRoutes(machine)
advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine)
if err != nil {
return err
}
for _, newRoute := range newRoutes {
if !contains(advertisedRoutes, newRoute) {
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Hostname,
@ -1047,7 +1069,7 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
// Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes {
route := Route{}
err := h.db.Preload("Machine").
err := hsdb.db.Preload("Machine").
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
First(&route).Error
if err == nil {
@ -1056,10 +1078,10 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
// Mark already as primary if there is only this node offering this subnet
// (and is not an exit route)
if !route.isExitRoute() {
route.IsPrimary = h.isUniquePrefix(route)
route.IsPrimary = hsdb.isUniquePrefix(route)
}
err = h.db.Save(&route).Error
err = hsdb.db.Save(&route).Error
if err != nil {
return fmt.Errorf("failed to enable route: %w", err)
}
@ -1068,19 +1090,19 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
}
}
h.setLastStateChangeToNow()
hsdb.notifyStateChange()
return nil
}
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error {
if len(machine.IPAddresses) == 0 {
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
routes := []Route{}
err := h.db.
err := hsdb.db.
Preload("Machine").
Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).
Find(&routes).Error
@ -1097,7 +1119,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
approvedRoutes := []Route{}
for _, advertisedRoute := range routes {
routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers(
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
netip.Prefix(advertisedRoute.Prefix),
)
if err != nil {
@ -1113,7 +1135,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
if approvedAlias == machine.User.Name {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
approvedIps, err := h.aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, h.cfg.OIDC.StripEmaildomain)
approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain)
if err != nil {
log.Err(err).
Str("alias", approvedAlias).
@ -1132,7 +1154,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
for i, approvedRoute := range approvedRoutes {
approvedRoutes[i].Enabled = true
err = h.db.Save(&approvedRoutes[i]).Error
err = hsdb.db.Save(&approvedRoutes[i]).Error
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).
@ -1146,10 +1168,10 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
return nil
}
func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
normalizedHostname, err := NormalizeToFQDNRules(
suppliedName,
h.cfg.OIDC.StripEmaildomain,
hsdb.stripEmailDomain,
)
if err != nil {
return "", err
@ -1162,7 +1184,7 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
}
suffix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
if err != nil {
return "", err
}
@ -1173,21 +1195,21 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
return normalizedHostname, nil
}
func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
givenName, err := h.generateGivenName(suppliedName, false)
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
givenName, err := hsdb.generateGivenName(suppliedName, false)
if err != nil {
return "", err
}
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
machines, err := h.ListMachinesByGivenName(givenName)
machines, err := hsdb.ListMachinesByGivenName(givenName)
if err != nil {
return "", err
}
for _, machine := range machines {
if machine.MachineKey != machineKey && machine.GivenName == givenName {
postfixedName, err := h.generateGivenName(suppliedName, true)
postfixedName, err := hsdb.generateGivenName(suppliedName, true)
if err != nil {
return "", err
}