Factor wgkey to types/key
This commit converts all the uses of wgkey to the new key interfaces. It now has specific machine, node and discovery keys and we now should use them correctly. Please note the new logic which strips a key prefix (in utils.go) that is now standard inside tailscale. In theory we could put it in the database, but to preserve backwards compatibility and not spend a lot of resources on accounting for both, we just strip them.
This commit is contained in:
parent
07418140a2
commit
cfd53bc4aa
7 changed files with 184 additions and 143 deletions
92
api.go
92
api.go
|
@ -13,9 +13,10 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/mem"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/wgkey"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -34,7 +35,7 @@ func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
|||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"text/plain; charset=utf-8",
|
||||
[]byte(h.publicKey.HexString()),
|
||||
[]byte(MachinePublicKeyStripPrefix(*h.publicKey)),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -73,10 +74,10 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
|||
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
machineKeyStr := ctx.Param("id")
|
||||
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
||||
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse machine key")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
|
@ -88,7 +89,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
err = decode(body, &req, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
|
@ -98,17 +99,17 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
||||
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: machineKey.HexString(),
|
||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
}
|
||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||
|
@ -125,7 +126,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
// - Trying to log out (sending a expiry in the past)
|
||||
// - A valid, registered machine, looking for the node map
|
||||
// - Expired machine wanting to reauthenticate
|
||||
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
if machine.NodeKey == req.NodeKey.String() {
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||
|
@ -144,7 +145,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
||||
if machine.NodeKey == req.OldNodeKey.String() &&
|
||||
!machine.isExpired() {
|
||||
h.handleMachineRefreshKey(ctx, machineKey, req, *machine)
|
||||
|
||||
|
@ -168,7 +169,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
|||
}
|
||||
|
||||
func (h *Headscale) getMapResponse(
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
req tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
) ([]byte, error) {
|
||||
|
@ -179,6 +180,7 @@ func (h *Headscale) getMapResponse(
|
|||
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot convert to node")
|
||||
|
@ -189,6 +191,7 @@ func (h *Headscale) getMapResponse(
|
|||
peers, err := h.getValidPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
@ -201,6 +204,7 @@ func (h *Headscale) getMapResponse(
|
|||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Failed to convert peers to Tailscale nodes")
|
||||
|
@ -238,10 +242,7 @@ func (h *Headscale) getMapResponse(
|
|||
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
|
||||
} else {
|
||||
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
|
@ -257,7 +258,7 @@ func (h *Headscale) getMapResponse(
|
|||
}
|
||||
|
||||
func (h *Headscale) getMapKeepAliveResponse(
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
) ([]byte, error) {
|
||||
mapResponse := tailcfg.MapResponse{
|
||||
|
@ -269,10 +270,7 @@ func (h *Headscale) getMapKeepAliveResponse(
|
|||
src, _ := json.Marshal(mapResponse)
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
|
||||
} else {
|
||||
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
|
@ -288,13 +286,12 @@ func (h *Headscale) getMapKeepAliveResponse(
|
|||
|
||||
func (h *Headscale) handleMachineLogOut(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
machine Machine,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
log.Info().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client requested logout")
|
||||
|
||||
|
@ -306,7 +303,7 @@ func (h *Headscale) handleMachineLogOut(
|
|||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
@ -318,14 +315,13 @@ func (h *Headscale) handleMachineLogOut(
|
|||
|
||||
func (h *Headscale) handleMachineValidRegistration(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
machine Machine,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
// The machine registration is valid, respond with redirect to /map
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||
|
||||
|
@ -337,7 +333,7 @@ func (h *Headscale) handleMachineValidRegistration(
|
|||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||
|
@ -353,7 +349,7 @@ func (h *Headscale) handleMachineValidRegistration(
|
|||
|
||||
func (h *Headscale) handleMachineExpired(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
|
@ -361,7 +357,6 @@ func (h *Headscale) handleMachineExpired(
|
|||
|
||||
// The client has registered before, but has expired
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
|
@ -373,16 +368,16 @@ func (h *Headscale) handleMachineExpired(
|
|||
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
|
||||
}
|
||||
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
|
||||
|
@ -398,17 +393,16 @@ func (h *Headscale) handleMachineExpired(
|
|||
|
||||
func (h *Headscale) handleMachineRefreshKey(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", machine.Name).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString()
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
h.db.Save(&machine)
|
||||
|
||||
resp.AuthURL = ""
|
||||
|
@ -416,7 +410,7 @@ func (h *Headscale) handleMachineRefreshKey(
|
|||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
@ -428,7 +422,7 @@ func (h *Headscale) handleMachineRefreshKey(
|
|||
|
||||
func (h *Headscale) handleMachineRegistrationNew(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
|
@ -436,18 +430,17 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
|
||||
// The machine registration is new, redirect the client to the registration URL
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", machine.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.HexString(),
|
||||
machineKey.String(),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
|
||||
}
|
||||
|
||||
if !registerRequest.Expiry.IsZero() {
|
||||
|
@ -457,19 +450,21 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
Time("expiry", registerRequest.Expiry).
|
||||
Msg("Non-zero expiry time requested, adding to cache")
|
||||
h.requestedExpiryCache.Set(
|
||||
machineKey.HexString(),
|
||||
machineKey.String(),
|
||||
registerRequest.Expiry,
|
||||
requestedExpiryCacheExpiration,
|
||||
)
|
||||
}
|
||||
|
||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString() // save the NodeKey
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
|
||||
// save the NodeKey
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
@ -481,7 +476,7 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
|
||||
func (h *Headscale) handleAuthKey(
|
||||
ctx *gin.Context,
|
||||
machineKey wgkey.Key,
|
||||
machineKey key.MachinePublic,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
|
@ -493,6 +488,7 @@ func (h *Headscale) handleAuthKey(
|
|||
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
|
@ -501,6 +497,7 @@ func (h *Headscale) handleAuthKey(
|
|||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
|
@ -513,6 +510,7 @@ func (h *Headscale) handleAuthKey(
|
|||
}
|
||||
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
|
@ -537,6 +535,7 @@ func (h *Headscale) handleAuthKey(
|
|||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed to find an available IP")
|
||||
|
@ -555,9 +554,9 @@ func (h *Headscale) handleAuthKey(
|
|||
machine.AuthKeyID = uint(pak.ID)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = pak.NamespaceID
|
||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).
|
||||
HexString()
|
||||
// we update it just in case
|
||||
|
||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
// we update it just in case
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = RegisterMethodAuthKey
|
||||
h.db.Save(&machine)
|
||||
|
@ -571,6 +570,7 @@ func (h *Headscale) handleAuthKey(
|
|||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue