create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -11,6 +11,8 @@ import (
|
|||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"go4.org/netipx"
|
||||
|
@ -21,23 +23,23 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
ErrMachineNotFound = Error("machine not found")
|
||||
ErrMachineRouteIsNotAvailable = Error("route is not available on machine")
|
||||
ErrMachineAddressesInvalid = Error("failed to parse machine addresses")
|
||||
ErrMachineNotFoundRegistrationCache = Error(
|
||||
"machine not found in registration cache",
|
||||
)
|
||||
ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface")
|
||||
ErrHostnameTooLong = Error("Hostname too long")
|
||||
ErrDifferentRegisteredUser = Error(
|
||||
"machine was previously registered with a different user",
|
||||
)
|
||||
MachineGivenNameHashLength = 8
|
||||
MachineGivenNameTrimSize = 2
|
||||
maxHostnameLength = 255
|
||||
)
|
||||
|
||||
const (
|
||||
maxHostnameLength = 255
|
||||
var (
|
||||
ErrMachineNotFound = errors.New("machine not found")
|
||||
ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine")
|
||||
ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
|
||||
ErrMachineNotFoundRegistrationCache = errors.New(
|
||||
"machine not found in registration cache",
|
||||
)
|
||||
ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface")
|
||||
ErrHostnameTooLong = errors.New("hostname too long")
|
||||
ErrDifferentRegisteredUser = errors.New(
|
||||
"machine was previously registered with a different user",
|
||||
)
|
||||
)
|
||||
|
||||
// Machine is a Headscale client.
|
||||
|
@ -188,8 +190,10 @@ func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine
|
|||
|
||||
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
||||
// related to the application outside of tests.
|
||||
func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines {
|
||||
return filterMachinesByACL(currentMachine, peers, h.aclRules)
|
||||
func (hsdb *HSDatabase) filterMachinesByACL(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
currentMachine *Machine, peers Machines) Machines {
|
||||
return filterMachinesByACL(currentMachine, peers, aclRules)
|
||||
}
|
||||
|
||||
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||
|
@ -213,14 +217,14 @@ func filterMachinesByACL(
|
|||
return result
|
||||
}
|
||||
|
||||
func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
||||
func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
machines := Machines{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
|
||||
machine.NodeKey).Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
|
||||
|
@ -237,23 +241,27 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
|||
return machines, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||
func (hsdb *HSDatabase) getPeers(
|
||||
aclPolicy *ACLPolicy,
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *Machine,
|
||||
) (Machines, error) {
|
||||
var peers Machines
|
||||
var err error
|
||||
|
||||
// If ACLs rules are defined, filter visible host list with the ACLs
|
||||
// else use the classic user scope
|
||||
if h.aclPolicy != nil {
|
||||
if aclPolicy != nil {
|
||||
var machines []Machine
|
||||
machines, err = h.ListMachines()
|
||||
machines, err = hsdb.ListMachines()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error retrieving list of machines")
|
||||
|
||||
return Machines{}, err
|
||||
}
|
||||
peers = h.filterMachinesByACL(machine, machines)
|
||||
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
|
||||
} else {
|
||||
peers, err = h.ListPeers(machine)
|
||||
peers, err = hsdb.ListPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -275,10 +283,14 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
|||
return peers, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
|
||||
func (hsdb *HSDatabase) getValidPeers(
|
||||
aclPolicy *ACLPolicy,
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *Machine,
|
||||
) (Machines, error) {
|
||||
validPeers := make(Machines, 0)
|
||||
|
||||
peers, err := h.getPeers(machine)
|
||||
peers, err := hsdb.getPeers(aclPolicy, aclRules, machine)
|
||||
if err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
@ -292,18 +304,18 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
|
|||
return validPeers, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) ListMachines() ([]Machine, error) {
|
||||
func (hsdb *HSDatabase) ListMachines() ([]Machine, error) {
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) {
|
||||
func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) {
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -311,8 +323,8 @@ func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error)
|
|||
}
|
||||
|
||||
// GetMachine finds a Machine by name and user and returns the Machine struct.
|
||||
func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
|
||||
machines, err := h.ListMachinesByUser(user)
|
||||
func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) {
|
||||
machines, err := hsdb.ListMachinesByUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -327,8 +339,8 @@ func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
|
|||
}
|
||||
|
||||
// GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
|
||||
machines, err := h.ListMachinesByUser(user)
|
||||
func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
|
||||
machines, err := hsdb.ListMachinesByUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -343,9 +355,9 @@ func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machi
|
|||
}
|
||||
|
||||
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||
func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) {
|
||||
m := Machine{}
|
||||
if result := h.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -353,11 +365,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
|||
}
|
||||
|
||||
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByMachineKey(
|
||||
func (hsdb *HSDatabase) GetMachineByMachineKey(
|
||||
machineKey key.MachinePublic,
|
||||
) (*Machine, error) {
|
||||
m := Machine{}
|
||||
if result := h.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -365,12 +377,12 @@ func (h *Headscale) GetMachineByMachineKey(
|
|||
}
|
||||
|
||||
// GetMachineByNodeKey finds a Machine by its current NodeKey.
|
||||
func (h *Headscale) GetMachineByNodeKey(
|
||||
func (hsdb *HSDatabase) GetMachineByNodeKey(
|
||||
nodeKey key.NodePublic,
|
||||
) (*Machine, error) {
|
||||
machine := Machine{}
|
||||
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
|
||||
NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
|
||||
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -378,14 +390,14 @@ func (h *Headscale) GetMachineByNodeKey(
|
|||
}
|
||||
|
||||
// GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByAnyKey(
|
||||
func (hsdb *HSDatabase) GetMachineByAnyKey(
|
||||
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
||||
) (*Machine, error) {
|
||||
machine := Machine{}
|
||||
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
|
||||
MachinePublicKeyStripPrefix(machineKey),
|
||||
NodePublicKeyStripPrefix(nodeKey),
|
||||
NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
|
||||
util.MachinePublicKeyStripPrefix(machineKey),
|
||||
util.NodePublicKeyStripPrefix(nodeKey),
|
||||
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -394,8 +406,8 @@ func (h *Headscale) GetMachineByAnyKey(
|
|||
|
||||
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
||||
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||
func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error {
|
||||
if result := hsdb.db.Find(machine).First(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -403,20 +415,28 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
|||
}
|
||||
|
||||
// SetTags takes a Machine struct pointer and update the forced tags.
|
||||
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
||||
func (hsdb *HSDatabase) SetTags(
|
||||
machine *Machine,
|
||||
tags []string,
|
||||
// TODO(kradalby): This is a temporary measure to be able to detach the
|
||||
// database completely from the global h. In the future, as part of this
|
||||
// reorg, the rules will be generated on a per node basis, and not be prone
|
||||
// to throwing error at save.
|
||||
updateACL func() error) error {
|
||||
newTags := []string{}
|
||||
for _, tag := range tags {
|
||||
if !contains(newTags, tag) {
|
||||
if !util.StringOrPrefixListContains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
machine.ForcedTags = newTags
|
||||
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
||||
if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
||||
return err
|
||||
}
|
||||
h.setLastStateChangeToNow()
|
||||
|
||||
if err := h.db.Save(machine).Error; err != nil {
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -424,13 +444,13 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
|||
}
|
||||
|
||||
// ExpireMachine takes a Machine struct and sets the expire field to now.
|
||||
func (h *Headscale) ExpireMachine(machine *Machine) error {
|
||||
func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error {
|
||||
now := time.Now()
|
||||
machine.Expiry = &now
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := h.db.Save(machine).Error; err != nil {
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -439,7 +459,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
|
|||
|
||||
// RenameMachine takes a Machine struct and a new GivenName for the machines
|
||||
// and renames it.
|
||||
func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
||||
func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error {
|
||||
err := CheckForFQDNRules(
|
||||
newName,
|
||||
)
|
||||
|
@ -455,9 +475,9 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
|||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := h.db.Save(machine).Error; err != nil {
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -465,15 +485,15 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
|||
}
|
||||
|
||||
// RefreshMachine takes a Machine struct and sets the expire field to now.
|
||||
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
|
||||
func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error {
|
||||
now := time.Now()
|
||||
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &expiry
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := h.db.Save(machine).Error; err != nil {
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to refresh machine (update expiration) in the database: %w",
|
||||
err,
|
||||
|
@ -484,21 +504,21 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
|
|||
}
|
||||
|
||||
// DeleteMachine softs deletes a Machine from the database.
|
||||
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
||||
err := h.DeleteMachineRoutes(machine)
|
||||
func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error {
|
||||
err := hsdb.DeleteMachineRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.db.Delete(&machine).Error; err != nil {
|
||||
if err := hsdb.db.Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) TouchMachine(machine *Machine) error {
|
||||
return h.db.Updates(Machine{
|
||||
func (hsdb *HSDatabase) TouchMachine(machine *Machine) error {
|
||||
return hsdb.db.Updates(Machine{
|
||||
ID: machine.ID,
|
||||
LastSeen: machine.LastSeen,
|
||||
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
|
||||
|
@ -506,13 +526,13 @@ func (h *Headscale) TouchMachine(machine *Machine) error {
|
|||
}
|
||||
|
||||
// HardDeleteMachine hard deletes a Machine from the database.
|
||||
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
||||
err := h.DeleteMachineRoutes(machine)
|
||||
func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error {
|
||||
err := hsdb.DeleteMachineRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||
if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -524,8 +544,8 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
|
|||
return tailcfg.Hostinfo(machine.HostInfo)
|
||||
}
|
||||
|
||||
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||
if err := h.UpdateMachineFromDatabase(machine); err != nil {
|
||||
func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool {
|
||||
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
|
@ -536,7 +556,6 @@ func (h *Headscale) isOutdated(machine *Machine) bool {
|
|||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
||||
// This would mostly be for a bit of performance, and can be calculated based on
|
||||
// ACLs.
|
||||
lastChange := h.getLastStateChange()
|
||||
lastUpdate := machine.CreatedAt
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
|
@ -576,15 +595,16 @@ func (machines MachinesP) String() string {
|
|||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
||||
|
||||
func (h *Headscale) toNodes(
|
||||
func (hsdb *HSDatabase) toNodes(
|
||||
machines Machines,
|
||||
aclPolicy *ACLPolicy,
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
nodes := make([]*tailcfg.Node, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
node, err := h.toNode(machine, baseDomain, dnsConfig)
|
||||
node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -597,13 +617,14 @@ func (h *Headscale) toNodes(
|
|||
|
||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||
// as per the expected behaviour in the official SaaS.
|
||||
func (h *Headscale) toNode(
|
||||
func (hsdb *HSDatabase) toNode(
|
||||
machine Machine,
|
||||
aclPolicy *ACLPolicy,
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
) (*tailcfg.Node, error) {
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey)))
|
||||
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)))
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
Caller().
|
||||
|
@ -617,7 +638,7 @@ func (h *Headscale) toNode(
|
|||
// MachineKey is only used in the legacy protocol
|
||||
if machine.MachineKey != "" {
|
||||
err = machineKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||
|
@ -627,7 +648,7 @@ func (h *Headscale) toNode(
|
|||
var discoKey key.DiscoPublic
|
||||
if machine.DiscoKey != "" {
|
||||
err := discoKey.UnmarshalText(
|
||||
[]byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
|
||||
[]byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||
|
@ -646,13 +667,13 @@ func (h *Headscale) toNode(
|
|||
[]netip.Prefix{},
|
||||
addrs...) // we append the node own IP, as it is required by the clients
|
||||
|
||||
primaryRoutes, err := h.getMachinePrimaryRoutes(&machine)
|
||||
primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
primaryPrefixes := Routes(primaryRoutes).toPrefixes()
|
||||
|
||||
machineRoutes, err := h.GetMachineRoutes(&machine)
|
||||
machineRoutes, err := hsdb.GetMachineRoutes(&machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -699,13 +720,13 @@ func (h *Headscale) toNode(
|
|||
|
||||
online := machine.isOnline()
|
||||
|
||||
tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain)
|
||||
tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain)
|
||||
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
||||
|
||||
node := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
strconv.FormatUint(machine.ID, Base10),
|
||||
strconv.FormatUint(machine.ID, util.Base10),
|
||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
Name: hostname,
|
||||
|
||||
|
@ -827,7 +848,8 @@ func getTags(
|
|||
return validTags, invalidTags
|
||||
}
|
||||
|
||||
func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||
cache *cache.Cache,
|
||||
nodeKeyStr string,
|
||||
userName string,
|
||||
machineExpiry *time.Time,
|
||||
|
@ -846,9 +868,9 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
|||
Str("expiresAt", fmt.Sprintf("%v", machineExpiry)).
|
||||
Msg("Registering machine from API/CLI or auth callback")
|
||||
|
||||
if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok {
|
||||
if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
|
||||
if registrationMachine, ok := machineInterface.(Machine); ok {
|
||||
user, err := h.GetUser(userName)
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register machine from auth callback, %w",
|
||||
|
@ -869,12 +891,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
|||
registrationMachine.Expiry = machineExpiry
|
||||
}
|
||||
|
||||
machine, err := h.RegisterMachine(
|
||||
machine, err := hsdb.RegisterMachine(
|
||||
registrationMachine,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
h.registrationCache.Delete(nodeKeyStr)
|
||||
cache.Delete(nodeKeyStr)
|
||||
}
|
||||
|
||||
return machine, err
|
||||
|
@ -887,7 +909,7 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
|||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (h *Headscale) RegisterMachine(machine Machine,
|
||||
func (hsdb *HSDatabase) RegisterMachine(machine Machine,
|
||||
) (*Machine, error) {
|
||||
log.Debug().
|
||||
Str("machine", machine.Hostname).
|
||||
|
@ -900,7 +922,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
// so we store the machine.Expire and machine.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if len(machine.IPAddresses) > 0 {
|
||||
if err := h.db.Save(&machine).Error; err != nil {
|
||||
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -915,10 +937,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
return &machine, nil
|
||||
}
|
||||
|
||||
h.ipAllocationMutex.Lock()
|
||||
defer h.ipAllocationMutex.Unlock()
|
||||
hsdb.ipAllocationMutex.Lock()
|
||||
defer hsdb.ipAllocationMutex.Unlock()
|
||||
|
||||
ips, err := h.getAvailableIPs()
|
||||
ips, err := hsdb.getAvailableIPs()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -931,7 +953,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
|
||||
machine.IPAddresses = ips
|
||||
|
||||
if err := h.db.Save(&machine).Error; err != nil {
|
||||
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -945,10 +967,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
}
|
||||
|
||||
// GetAdvertisedRoutes returns the routes that are be advertised by the given machine.
|
||||
func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||
func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||
routes := []Route{}
|
||||
|
||||
err := h.db.
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
@ -970,10 +992,10 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error
|
|||
}
|
||||
|
||||
// GetEnabledRoutes returns the routes that are enabled for the machine.
|
||||
func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||
func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||
routes := []Route{}
|
||||
|
||||
err := h.db.
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true).
|
||||
Find(&routes).Error
|
||||
|
@ -995,13 +1017,13 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
|||
return prefixes, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
||||
func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := h.GetEnabledRoutes(machine)
|
||||
enabledRoutes, err := hsdb.GetEnabledRoutes(machine)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Could not get enabled routes")
|
||||
|
||||
|
@ -1018,7 +1040,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
|||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||
func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
|
@ -1029,13 +1051,13 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
|||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
advertisedRoutes, err := h.GetAdvertisedRoutes(machine)
|
||||
advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !contains(advertisedRoutes, newRoute) {
|
||||
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
|
||||
return fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
machine.Hostname,
|
||||
|
@ -1047,7 +1069,7 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
|||
// Separate loop so we don't leave things in a half-updated state
|
||||
for _, prefix := range newRoutes {
|
||||
route := Route{}
|
||||
err := h.db.Preload("Machine").
|
||||
err := hsdb.db.Preload("Machine").
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
||||
First(&route).Error
|
||||
if err == nil {
|
||||
|
@ -1056,10 +1078,10 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
|||
// Mark already as primary if there is only this node offering this subnet
|
||||
// (and is not an exit route)
|
||||
if !route.isExitRoute() {
|
||||
route.IsPrimary = h.isUniquePrefix(route)
|
||||
route.IsPrimary = hsdb.isUniquePrefix(route)
|
||||
}
|
||||
|
||||
err = h.db.Save(&route).Error
|
||||
err = hsdb.db.Save(&route).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to enable route: %w", err)
|
||||
}
|
||||
|
@ -1068,19 +1090,19 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
|||
}
|
||||
}
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
||||
func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error {
|
||||
if len(machine.IPAddresses) == 0 {
|
||||
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
routes := []Route{}
|
||||
err := h.db.
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).
|
||||
Find(&routes).Error
|
||||
|
@ -1097,7 +1119,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
|||
approvedRoutes := []Route{}
|
||||
|
||||
for _, advertisedRoute := range routes {
|
||||
routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||
netip.Prefix(advertisedRoute.Prefix),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -1113,7 +1135,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
|||
if approvedAlias == machine.User.Name {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
approvedIps, err := h.aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, h.cfg.OIDC.StripEmaildomain)
|
||||
approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("alias", approvedAlias).
|
||||
|
@ -1132,7 +1154,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
|||
|
||||
for i, approvedRoute := range approvedRoutes {
|
||||
approvedRoutes[i].Enabled = true
|
||||
err = h.db.Save(&approvedRoutes[i]).Error
|
||||
err = hsdb.db.Save(&approvedRoutes[i]).Error
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
|
@ -1146,10 +1168,10 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
normalizedHostname, err := NormalizeToFQDNRules(
|
||||
suppliedName,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
hsdb.stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -1162,7 +1184,7 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
|
|||
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
||||
}
|
||||
|
||||
suffix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
||||
suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -1173,21 +1195,21 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
|
|||
return normalizedHostname, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
||||
givenName, err := h.generateGivenName(suppliedName, false)
|
||||
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
||||
givenName, err := hsdb.generateGivenName(suppliedName, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||
machines, err := h.ListMachinesByGivenName(givenName)
|
||||
machines, err := hsdb.ListMachinesByGivenName(givenName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, machine := range machines {
|
||||
if machine.MachineKey != machineKey && machine.GivenName == givenName {
|
||||
postfixedName, err := h.generateGivenName(suppliedName, true)
|
||||
postfixedName, err := hsdb.generateGivenName(suppliedName, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue