Factor wgkey to types/key

This commit converts all the uses of wgkey to the new key interfaces.

It now has specific  machine, node and discovery keys and we now should
use them correctly.

Please note the new logic which strips a key prefix (in utils.go) that
is now standard inside tailscale.

In theory we could put it in the database, but to preserve backwards
compatibility and not spend a lot of resources on accounting for both,
we just strip them.
This commit is contained in:
Kristoffer Dalby 2021-11-26 23:30:42 +00:00
parent 07418140a2
commit cfd53bc4aa
7 changed files with 184 additions and 143 deletions

View file

@ -12,12 +12,12 @@ import (
"github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"go4.org/mem"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
"tailscale.com/types/key"
)
const (
@ -260,9 +260,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
}
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
return nil, result.Error
}
@ -437,25 +439,31 @@ func (machine Machine) toNode(
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
if err != nil {
return nil, err
log.Trace().
Caller().
Str("node_key", machine.NodeKey).
Msgf("Failed to parse node public key from hex")
return nil, fmt.Errorf("failed to parse node public key: %w", err)
}
machineKey, err := wgkey.ParseHex(machine.MachineKey)
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
var discoKey tailcfg.DiscoKey
var discoKey key.DiscoPublic
if machine.DiscoKey != "" {
dKey, err := wgkey.ParseHex(machine.DiscoKey)
dKey := key.DiscoPublic{}
err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
}
discoKey = tailcfg.DiscoKey(dKey)
discoKey = key.DiscoPublic(dKey)
} else {
discoKey = tailcfg.DiscoKey{}
discoKey = key.DiscoPublic{}
}
addrs := []netaddr.IPPrefix{}
@ -555,9 +563,9 @@ func (machine Machine) toNode(
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
User: tailcfg.UserID(machine.NamespaceID),
Key: tailcfg.NodeKey(nodeKey),
Key: nodeKey,
KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(machineKey),
Machine: machineKey,
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
@ -618,31 +626,35 @@ func (machine *Machine) toProto() *v1.Machine {
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(
key string,
machineKeyStr string,
namespaceName string,
) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
machineKey, err := wgkey.ParseHex(key)
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
if err != nil {
return nil, err
}
machine := Machine{}
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errMachineNotFound
log.Trace().
Caller().
Str("machine_key_str", machineKeyStr).
Str("machine_key", machineKey.String()).
Msg("Registering machine")
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
return nil, err
}
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found {
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
log.Trace().
Caller().
Str("machine", machine.Name).
@ -658,9 +670,9 @@ func (h *Headscale) RegisterMachine(
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, requestedTime)
h.RefreshMachine(machine, requestedTime)
return &machine, nil
return machine, nil
}
log.Trace().
@ -709,7 +721,7 @@ func (h *Headscale) RegisterMachine(
Str("ip", ip.String()).
Msg("Machine registered with the database")
return &machine, nil
return machine, nil
}
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {