Merge branch 'main' into db-error-handling
This commit is contained in:
commit
0676aa11a9
36 changed files with 1522 additions and 554 deletions
135
machine.go
135
machine.go
|
@ -9,7 +9,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/set"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -27,6 +26,7 @@ const (
|
|||
)
|
||||
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
|
||||
errHostnameTooLong = Error("Hostname too long")
|
||||
MachineGivenNameHashLength = 8
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -40,7 +40,18 @@ type Machine struct {
|
|||
NodeKey string
|
||||
DiscoKey string
|
||||
IPAddresses MachineAddresses
|
||||
Name string
|
||||
|
||||
// Hostname represents the name given by the Tailscale
|
||||
// client during registration
|
||||
Hostname string
|
||||
|
||||
// Givenname represents either:
|
||||
// a DNS normalized version of Hostname
|
||||
// a valid name set by the User
|
||||
//
|
||||
// GivenName is the name used in all DNS related
|
||||
// parts of headscale.
|
||||
GivenName string `gorm:"type:varchar(63);unique_index"`
|
||||
NamespaceID uint
|
||||
Namespace Namespace `gorm:"foreignKey:NamespaceID"`
|
||||
|
||||
|
@ -152,7 +163,7 @@ func getFilteredByACLPeers(
|
|||
) Machines {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Finding peers filtered by ACLs")
|
||||
|
||||
peers := make(map[uint64]Machine)
|
||||
|
@ -219,7 +230,7 @@ func getFilteredByACLPeers(
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found some machines: %v", machines)
|
||||
|
||||
return authorizedPeers
|
||||
|
@ -228,7 +239,7 @@ func getFilteredByACLPeers(
|
|||
func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
machines := Machines{}
|
||||
|
@ -243,7 +254,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found peers: %s", machines.String())
|
||||
|
||||
return machines, nil
|
||||
|
@ -280,7 +291,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found total peers: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
|
@ -320,7 +331,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
|
|||
}
|
||||
|
||||
for _, m := range machines {
|
||||
if m.Name == name {
|
||||
if m.Hostname == name {
|
||||
return &m, nil
|
||||
}
|
||||
}
|
||||
|
@ -350,9 +361,9 @@ func (h *Headscale) GetMachineByMachineKey(
|
|||
return &m, nil
|
||||
}
|
||||
|
||||
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (h *Headscale) UpdateMachine(machine *Machine) error {
|
||||
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
||||
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
@ -389,8 +400,35 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RefreshMachine takes a Machine struct and sets the expire field.
|
||||
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
|
||||
// RenameMachine takes a Machine struct and a new GivenName for the machines
|
||||
// and renames it.
|
||||
func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
||||
err := CheckForFQDNRules(
|
||||
newName,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RenameMachine").
|
||||
Str("machine", machine.Hostname).
|
||||
Str("newName", newName).
|
||||
Err(err)
|
||||
|
||||
return err
|
||||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
h.setLastStateChangeToNow(machine.Namespace.Name)
|
||||
|
||||
if err := h.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshMachine takes a Machine struct and sets the expire field to now.
|
||||
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
|
||||
now := time.Now()
|
||||
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
|
@ -437,46 +475,41 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
|
|||
}
|
||||
|
||||
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||
if err := h.UpdateMachine(machine); err != nil {
|
||||
if err := h.UpdateMachineFromDatabase(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
namespaceSet := set.New(set.ThreadSafe)
|
||||
namespaceSet.Add(machine.Namespace.Name)
|
||||
|
||||
namespaces := make([]string, namespaceSet.Size())
|
||||
for index, namespace := range namespaceSet.List() {
|
||||
if name, ok := namespace.(string); ok {
|
||||
namespaces[index] = name
|
||||
}
|
||||
}
|
||||
|
||||
lastChange := h.getLastStateChange(namespaces...)
|
||||
// Get the last update from all headscale namespaces to compare with our nodes
|
||||
// last update.
|
||||
// TODO(kradalby): Only request updates from namespaces where we can talk to nodes
|
||||
// This would mostly be for a bit of performance, and can be calculated based on
|
||||
// ACLs.
|
||||
lastChange := h.getLastStateChange()
|
||||
lastUpdate := machine.CreatedAt
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastChange).
|
||||
Time("last_state_change", lastUpdate).
|
||||
Msgf("Checking if %s is missing updates", machine.Name)
|
||||
Msgf("Checking if %s is missing updates", machine.Hostname)
|
||||
|
||||
return lastUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (machine Machine) String() string {
|
||||
return machine.Name
|
||||
return machine.Hostname
|
||||
}
|
||||
|
||||
func (machines Machines) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
temp[index] = machine.Hostname
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
|
@ -487,7 +520,7 @@ func (machines MachinesP) String() string {
|
|||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
temp[index] = machine.Hostname
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
|
@ -584,7 +617,7 @@ func (machine Machine) toNode(
|
|||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||
hostname = fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
machine.Name,
|
||||
machine.GivenName,
|
||||
machine.Namespace.Name,
|
||||
baseDomain,
|
||||
)
|
||||
|
@ -596,7 +629,7 @@ func (machine Machine) toNode(
|
|||
)
|
||||
}
|
||||
} else {
|
||||
hostname = machine.Name
|
||||
hostname = machine.GivenName
|
||||
}
|
||||
|
||||
hostInfo := machine.GetHostInfo()
|
||||
|
@ -637,7 +670,8 @@ func (machine *Machine) toProto() *v1.Machine {
|
|||
NodeKey: machine.NodeKey,
|
||||
DiscoKey: machine.DiscoKey,
|
||||
IpAddresses: machine.IPAddresses.ToStringSlice(),
|
||||
Name: machine.Name,
|
||||
Name: machine.Hostname,
|
||||
GivenName: machine.GivenName,
|
||||
Namespace: machine.Namespace.toProto(),
|
||||
ForcedTags: machine.ForcedTags,
|
||||
|
||||
|
@ -753,7 +787,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Attempting to register machine")
|
||||
|
||||
h.ipAllocationMutex.Lock()
|
||||
|
@ -764,7 +798,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not find IP for the new machine")
|
||||
|
||||
return nil, err
|
||||
|
@ -778,7 +812,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("ip", strings.Join(ips.ToStringSlice(), ",")).
|
||||
Msg("Machine registered with the database")
|
||||
|
||||
|
@ -827,7 +861,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
|
|||
if !contains(machine.GetAdvertisedRoutes(), newRoute) {
|
||||
return fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
machine.Name,
|
||||
machine.Hostname,
|
||||
newRoute, errMachineRouteIsNotAvailable,
|
||||
)
|
||||
}
|
||||
|
@ -852,3 +886,32 @@ func (machine *Machine) RoutesToProto() *v1.Routes {
|
|||
EnabledRoutes: ipPrefixToString(enabledRoutes),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) {
|
||||
// If a hostname is or will be longer than 63 chars after adding the hash,
|
||||
// it needs to be trimmed.
|
||||
trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2
|
||||
|
||||
normalizedHostname, err := NormalizeToFQDNRules(
|
||||
suppliedName,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
postfix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Verify that that the new unique name is shorter than the maximum allowed
|
||||
// DNS segment.
|
||||
if len(normalizedHostname) <= trimmedHostnameLength {
|
||||
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname, postfix)
|
||||
} else {
|
||||
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname[:trimmedHostnameLength], postfix)
|
||||
}
|
||||
|
||||
return normalizedHostname, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue