Merge branch 'main' into db-error-handling

This commit is contained in:
Kristoffer Dalby 2022-05-31 10:18:13 +02:00 committed by GitHub
commit 0676aa11a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 1522 additions and 554 deletions

View file

@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
@ -27,6 +26,7 @@ const (
)
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long")
MachineGivenNameHashLength = 8
)
const (
@ -40,7 +40,18 @@ type Machine struct {
NodeKey string
DiscoKey string
IPAddresses MachineAddresses
Name string
// Hostname represents the name given by the Tailscale
// client during registration
Hostname string
// Givenname represents either:
// a DNS normalized version of Hostname
// a valid name set by the User
//
// GivenName is the name used in all DNS related
// parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"`
NamespaceID uint
Namespace Namespace `gorm:"foreignKey:NamespaceID"`
@ -152,7 +163,7 @@ func getFilteredByACLPeers(
) Machines {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msg("Finding peers filtered by ACLs")
peers := make(map[uint64]Machine)
@ -219,7 +230,7 @@ func getFilteredByACLPeers(
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msgf("Found some machines: %v", machines)
return authorizedPeers
@ -228,7 +239,7 @@ func getFilteredByACLPeers(
func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msg("Finding direct peers")
machines := Machines{}
@ -243,7 +254,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msgf("Found peers: %s", machines.String())
return machines, nil
@ -280,7 +291,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msgf("Found total peers: %s", peers.String())
return peers, nil
@ -320,7 +331,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
}
for _, m := range machines {
if m.Name == name {
if m.Hostname == name {
return &m, nil
}
}
@ -350,9 +361,9 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(machine *Machine) error {
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
@ -389,8 +400,35 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
return nil
}
// RefreshMachine takes a Machine struct and sets the expire field.
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
// RenameMachine takes a Machine struct and a new GivenName for the machines
// and renames it.
func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
err := CheckForFQDNRules(
newName,
)
if err != nil {
log.Error().
Caller().
Str("func", "RenameMachine").
Str("machine", machine.Hostname).
Str("newName", newName).
Err(err)
return err
}
machine.GivenName = newName
h.setLastStateChangeToNow(machine.Namespace.Name)
if err := h.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err)
}
return nil
}
// RefreshMachine takes a Machine struct and sets the expire field to now.
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
now := time.Now()
machine.LastSuccessfulUpdate = &now
@ -437,46 +475,41 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
}
func (h *Headscale) isOutdated(machine *Machine) bool {
if err := h.UpdateMachine(machine); err != nil {
if err := h.UpdateMachineFromDatabase(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(machine.Namespace.Name)
namespaces := make([]string, namespaceSet.Size())
for index, namespace := range namespaceSet.List() {
if name, ok := namespace.(string); ok {
namespaces[index] = name
}
}
lastChange := h.getLastStateChange(namespaces...)
// Get the last update from all headscale namespaces to compare with our nodes
// last update.
// TODO(kradalby): Only request updates from namespaces where we can talk to nodes
// This would mostly be for a bit of performance, and can be calculated based on
// ACLs.
lastChange := h.getLastStateChange()
lastUpdate := machine.CreatedAt
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Time("last_successful_update", lastChange).
Time("last_state_change", lastUpdate).
Msgf("Checking if %s is missing updates", machine.Name)
Msgf("Checking if %s is missing updates", machine.Hostname)
return lastUpdate.Before(lastChange)
}
func (machine Machine) String() string {
return machine.Name
return machine.Hostname
}
func (machines Machines) String() string {
temp := make([]string, len(machines))
for index, machine := range machines {
temp[index] = machine.Name
temp[index] = machine.Hostname
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
@ -487,7 +520,7 @@ func (machines MachinesP) String() string {
temp := make([]string, len(machines))
for index, machine := range machines {
temp[index] = machine.Name
temp[index] = machine.Hostname
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
@ -584,7 +617,7 @@ func (machine Machine) toNode(
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf(
"%s.%s.%s",
machine.Name,
machine.GivenName,
machine.Namespace.Name,
baseDomain,
)
@ -596,7 +629,7 @@ func (machine Machine) toNode(
)
}
} else {
hostname = machine.Name
hostname = machine.GivenName
}
hostInfo := machine.GetHostInfo()
@ -637,7 +670,8 @@ func (machine *Machine) toProto() *v1.Machine {
NodeKey: machine.NodeKey,
DiscoKey: machine.DiscoKey,
IpAddresses: machine.IPAddresses.ToStringSlice(),
Name: machine.Name,
Name: machine.Hostname,
GivenName: machine.GivenName,
Namespace: machine.Namespace.toProto(),
ForcedTags: machine.ForcedTags,
@ -753,7 +787,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msg("Attempting to register machine")
h.ipAllocationMutex.Lock()
@ -764,7 +798,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
log.Error().
Caller().
Err(err).
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Msg("Could not find IP for the new machine")
return nil, err
@ -778,7 +812,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
log.Trace().
Caller().
Str("machine", machine.Name).
Str("machine", machine.Hostname).
Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Machine registered with the database")
@ -827,7 +861,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
if !contains(machine.GetAdvertisedRoutes(), newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Name,
machine.Hostname,
newRoute, errMachineRouteIsNotAvailable,
)
}
@ -852,3 +886,32 @@ func (machine *Machine) RoutesToProto() *v1.Routes {
EnabledRoutes: ipPrefixToString(enabledRoutes),
}
}
func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) {
// If a hostname is or will be longer than 63 chars after adding the hash,
// it needs to be trimmed.
trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2
normalizedHostname, err := NormalizeToFQDNRules(
suppliedName,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
return "", err
}
postfix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
if err != nil {
return "", err
}
// Verify that that the new unique name is shorter than the maximum allowed
// DNS segment.
if len(normalizedHostname) <= trimmedHostnameLength {
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname, postfix)
} else {
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname[:trimmedHostnameLength], postfix)
}
return normalizedHostname, nil
}