Use tailscale key types instead of strings (#1609)

* upgrade tailscale

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* make Node object use actualy tailscale key types

This commit changes the Node struct to have both a field for strings
to store the keys in the database and a dedicated Key for each type
of key.

The keys are populated and stored with Gorm hooks to ensure the data
is stored in the db.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use key types throughout the code

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* make sure machinekey is concistently used

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use machine key in auth url

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix web register

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use key type in notifier

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix relogin with webauth

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-11-19 22:37:04 +01:00 committed by GitHub
parent c0fd06e3f5
commit ed4e19996b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 550 additions and 471 deletions

View file

@ -449,10 +449,10 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
h.addLegacyHandlers(router)
router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).

View file

@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
log.Debug().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
@ -78,7 +78,7 @@ func (h *Headscale) handleRegister(
Msg("New node not yet in the database")
givenName, err := h.db.GenerateGivenName(
machineKey.String(),
machineKey,
registerRequest.Hostinfo.Hostname,
)
if err != nil {
@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey.String(),
MachineKey: machineKey,
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: registerRequest.NodeKey.String(),
NodeKey: registerRequest.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
}
@ -116,7 +116,7 @@ func (h *Headscale) handleRegister(
}
h.registrationCache.Set(
newNode.NodeKey,
machineKey.String(),
newNode,
registerCacheExpiration,
)
@ -134,11 +134,7 @@ func (h *Headscale) handleRegister(
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(node.MachineKey),
)
if err != nil || storedMachineKey.IsZero() {
if err != nil || node.MachineKey.IsZero() {
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
log.Error().
Caller().
@ -156,7 +152,7 @@ func (h *Headscale) handleRegister(
// - Trying to log out (sending a expiry in the past)
// - A valid, registered node, looking for /map
// - Expired node wanting to reauthenticate
if node.NodeKey == registerRequest.NodeKey.String() {
if node.NodeKey.String() == registerRequest.NodeKey.String() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
@ -176,7 +172,7 @@ func (h *Headscale) handleRegister(
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if node.NodeKey == registerRequest.OldNodeKey.String() &&
if node.NodeKey.String() == registerRequest.OldNodeKey.String() &&
!node.IsExpired() {
h.handleNodeKeyRefresh(
writer,
@ -207,9 +203,9 @@ func (h *Headscale) handleRegister(
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = registerRequest.NodeKey.String()
node.NodeKey = registerRequest.NodeKey
h.registrationCache.Set(
registerRequest.NodeKey.String(),
machineKey.String(),
*node,
registerCacheExpiration,
)
@ -294,7 +290,7 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := registerRequest.NodeKey.String()
nodeKey := registerRequest.NodeKey
// retrieve node information if it exist
// The error is not important, because if it does not
@ -342,7 +338,7 @@ func (h *Headscale) handleAuthKey(
} else {
now := time.Now().UTC()
givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
@ -359,7 +355,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: machineKey.String(),
MachineKey: machineKey,
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
@ -460,12 +456,12 @@ func (h *Headscale) handleNewNode(
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey,
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
machineKey.String())
}
respBody, err := mapper.MarshalResponse(resp, isNoise, h.privateKey2019, machineKey)
@ -715,11 +711,11 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
machineKey.String())
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
machineKey.String())
}
respBody, err := mapper.MarshalResponse(resp, isNoise, h.privateKey2019, machineKey)

View file

@ -2,6 +2,7 @@ package db
import (
"context"
"database/sql"
"errors"
"fmt"
"net/netip"
@ -99,7 +100,7 @@ func NewHeadscaleDatabase(
// node was registered.
_ = dbConn.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
// If the MacNodehine table has a column for registered,
// If the Node table has a column for registered,
// find all occourences of "false" and drop them. Then
// remove the column.
if dbConn.Migrator().HasColumn(&types.Node{}, "registered") {
@ -114,13 +115,13 @@ func NewHeadscaleDatabase(
for _, node := range nodes {
log.Info().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("machine_key", node.MachineKey.ShortString()).
Msg("Deleting unregistered node")
if err := dbConn.Delete(&types.Node{}, node.ID).Error; err != nil {
log.Error().
Err(err).
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("machine_key", node.MachineKey.ShortString()).
Msg("Error deleting unregistered node")
}
}
@ -136,6 +137,50 @@ func NewHeadscaleDatabase(
return nil, err
}
err = dbConn.AutoMigrate(&types.Node{})
if err != nil {
return nil, err
}
// Ensure all keys have correct prefixes
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
type result struct {
ID uint64
MachineKey string
NodeKey string
DiscoKey string
}
var results []result
err = db.db.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
if err != nil {
return nil, err
}
for _, node := range results {
mKey := node.MachineKey
if !strings.HasPrefix(node.MachineKey, "mkey:") {
mKey = "mkey:" + node.MachineKey
}
nKey := node.NodeKey
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
nKey = "nodekey:" + node.NodeKey
}
dKey := node.DiscoKey
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
dKey = "discokey:" + node.DiscoKey
}
err := db.db.Exec("UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
sql.Named("mKey", mKey),
sql.Named("nKey", nKey),
sql.Named("dKey", dKey),
sql.Named("id", node.ID)).Error
if err != nil {
return nil, err
}
}
if dbConn.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
@ -195,11 +240,6 @@ func NewHeadscaleDatabase(
}
}
err = dbConn.AutoMigrate(&types.Node{})
if err != nil {
return nil, err
}
if dbConn.Migrator().HasColumn(&types.Node{}, "given_name") {
nodes := types.Nodes{}
if err := dbConn.Find(&nodes).Error; err != nil {
@ -253,27 +293,6 @@ func NewHeadscaleDatabase(
return nil, err
}
// Ensure all keys have correct prefixes
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
nodes := types.Nodes{}
if err := dbConn.Find(&nodes).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
for _, node := range nodes {
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
node.DiscoKey = "discokey:" + node.DiscoKey
}
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
node.NodeKey = "nodekey:" + node.NodeKey
}
if !strings.HasPrefix(node.MachineKey, "mkey:") {
node.MachineKey = "mkey:" + node.MachineKey
}
}
// TODO(kradalby): is this needed?
err = db.setValue("db_version", dbVersion)

View file

@ -55,7 +55,7 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
Preload("User").
Preload("Routes").
Where("node_key <> ?",
node.NodeKey).Find(&nodes).Error; err != nil {
node.NodeKey.String()).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
}
@ -268,7 +268,7 @@ func (hsdb *HSDatabase) SetTags(
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())
return nil
}
@ -304,7 +304,7 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())
return nil
}
@ -330,7 +330,7 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())
return nil
}
@ -376,7 +376,7 @@ func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
cache *cache.Cache,
nodeKeyStr string,
mkey key.MachinePublic,
userName string,
nodeExpiry *time.Time,
registrationMethod string,
@ -384,20 +384,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
nodeKey := key.NodePublic{}
err := nodeKey.UnmarshalText([]byte(nodeKeyStr))
if err != nil {
return nil, err
}
log.Debug().
Str("nodeKey", nodeKey.ShortString()).
Str("machine_key", mkey.ShortString()).
Str("userName", userName).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName)
if err != nil {
@ -425,7 +419,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
)
if err == nil {
cache.Delete(nodeKeyStr)
cache.Delete(mkey.String())
}
return node, err
@ -448,8 +442,8 @@ func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("node_key", node.NodeKey).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Msg("Registering node")
@ -464,8 +458,8 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("node_key", node.NodeKey).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Msg("Node authorized again")
@ -507,7 +501,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
NodeKey: nodeKey.String(),
NodeKey: nodeKey,
}).Error; err != nil {
return err
}
@ -524,7 +518,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: machineKey.String(),
MachineKey: machineKey,
}).Error; err != nil {
return err
}
@ -703,7 +697,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())
return nil
}
@ -734,7 +728,7 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil
}
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
@ -749,17 +743,22 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string
return "", err
}
for _, node := range nodes {
if node.MachineKey != machineKey && node.GivenName == givenName {
postfixedName, err := generateGivenName(suppliedName, true)
if err != nil {
return "", err
}
givenName = postfixedName
var nodeFound *types.Node
for idx, node := range nodes {
if node.GivenName == givenName {
nodeFound = nodes[idx]
}
}
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
postfixedName, err := generateGivenName(suppliedName, true)
if err != nil {
return "", err
}
givenName = postfixedName
}
return givenName, nil
}

View file

@ -25,11 +25,13 @@ func (s *Suite) TestGetNode(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -51,11 +53,13 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -82,9 +86,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
node := types.Node{
ID: 0,
MachineKey: machineKey.Public().String(),
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -113,9 +116,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
node := types.Node{
ID: 0,
MachineKey: machineKey.Public().String(),
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -130,11 +132,14 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
func (s *Suite) TestHardDeleteNode(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -160,11 +165,13 @@ func (s *Suite) TestListPeers(c *check.C) {
c.Assert(err, check.NotNil)
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -205,11 +212,13 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.NotNil)
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPAddresses: types.NodeAddresses{
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
},
@ -288,11 +297,13 @@ func (s *Suite) TestExpireNode(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -345,11 +356,15 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
_, err = db.GetNode("user-1", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
machineKey2 := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: "node-key-1",
NodeKey: "node-key-1",
DiscoKey: "disco-key-1",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "hostname-1",
GivenName: "hostname-1",
UserID: user1.ID,
@ -358,25 +373,20 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
}
db.db.Save(node)
givenName, err := db.GenerateGivenName("node-key-2", "hostname-2")
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-2", comment)
givenName, err = db.GenerateGivenName("node-key-1", "hostname-1")
givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
comment = check.Commentf("Same user, same node, same hostname, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-1", comment)
givenName, err = db.GenerateGivenName("node-key-2", "hostname-1")
givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
givenName, err = db.GenerateGivenName("node-key-2", "hostname-1")
comment = check.Commentf("Unique users, unique nodes, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
}
func (s *Suite) TestSetTags(c *check.C) {
@ -389,11 +399,13 @@ func (s *Suite) TestSetTags(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -565,6 +577,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
c.Assert(err, check.IsNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
defaultRouteV6 := netip.MustParsePrefix("::/0")
@ -574,9 +587,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,

View file

@ -1,6 +1,7 @@
package db
import (
"log"
"net/netip"
"os"
"testing"
@ -27,19 +28,22 @@ func (s *Suite) SetUpTest(c *check.C) {
}
func (s *Suite) TearDownTest(c *check.C) {
os.RemoveAll(tmpDir)
// os.RemoveAll(tmpDir)
}
func (s *Suite) ResetDB(c *check.C) {
if len(tmpDir) != 0 {
os.RemoveAll(tmpDir)
}
// if len(tmpDir) != 0 {
// os.RemoveAll(tmpDir)
// }
var err error
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
c.Fatal(err)
}
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",

View file

@ -172,12 +172,18 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("node_key", request.GetKey()).
Str("machine_key", request.GetKey()).
Msg("Registering node")
var mkey key.MachinePublic
err := mkey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
return nil, err
}
node, err := api.h.db.RegisterNodeFromAuthCallback(
api.h.registrationCache,
request.GetKey(),
mkey,
request.GetUser(),
nil,
util.RegisterMethodCLI,
@ -521,13 +527,22 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: "DebugTestNode",
}
givenName, err := api.h.db.GenerateGivenName(request.GetKey(), request.GetName())
var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
return nil, err
}
givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
if err != nil {
return nil, err
}
nodeKey := key.NewNode()
newNode := types.Node{
MachineKey: request.GetKey(),
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
GivenName: givenName,
User: *user,
@ -538,14 +553,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
HostInfo: types.HostInfo(hostinfo),
}
nodeKey := key.NodePublic{}
err = nodeKey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
log.Panic().Msg("can not add node for debug. invalid node key")
}
log.Debug().
Str("machine_key", mkey.ShortString()).
Msg("adding debug machine via CLI, appending to registration cache")
api.h.registrationCache.Set(
nodeKey.String(),
mkey.String(),
newNode,
registerCacheExpiration,
)

View file

@ -12,7 +12,6 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -207,33 +206,16 @@ func (h *Headscale) RegisterWebAPI(
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
machineKeyStr := vars["mkey"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(nodeKeyStr),
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)
if !ok || nodeKeyStr == "" || err != nil {
if err != nil {
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -251,7 +233,7 @@ func (h *Headscale) RegisterWebAPI(
var content bytes.Buffer
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
Key: nodeKeyStr,
Key: machineKey.String(),
}); err != nil {
log.Error().
Str("func", "RegisterWebAPI").

View file

@ -368,17 +368,6 @@ func (m *Mapper) marshalMapResponse(
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(node.MachineKey))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
@ -426,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
if !m.isNoise { // if legacy protocol
respBody = m.privateKey2019.SealTo(machineKey, respBody)
respBody = m.privateKey2019.SealTo(node.MachineKey, respBody)
}
} else {
if !m.isNoise { // if legacy protocol
respBody = m.privateKey2019.SealTo(machineKey, jsonBody)
respBody = m.privateKey2019.SealTo(node.MachineKey, jsonBody)
} else {
respBody = jsonBody
}

View file

@ -166,10 +166,16 @@ func Test_fullMapResponse(t *testing.T) {
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
mini := &types.Node{
ID: 0,
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
ID: 0,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
Hostname: "mini",
GivenName: "mini",
@ -226,7 +232,6 @@ func Test_fullMapResponse(t *testing.T) {
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("192.168.0.0/24"),
},
Endpoints: []string{},
DERP: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
@ -244,10 +249,16 @@ func Test_fullMapResponse(t *testing.T) {
}
peer1 := &types.Node{
ID: 1,
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
ID: 1,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
Hostname: "peer1",
GivenName: "peer1",
@ -278,7 +289,6 @@ func Test_fullMapResponse(t *testing.T) {
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Endpoints: []string{},
DERP: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
@ -296,10 +306,16 @@ func Test_fullMapResponse(t *testing.T) {
}
peer2 := &types.Node{
ID: 2,
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
ID: 2,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
Hostname: "peer2",
GivenName: "peer2",

View file

@ -52,21 +52,6 @@ func tailNode(
baseDomain string,
randomClientPort bool,
) (*tailcfg.Node, error) {
nodeKey, err := node.NodePublicKey()
if err != nil {
return nil, err
}
machineKey, err := node.MachinePublicKey()
if err != nil {
return nil, err
}
discoKey, err := node.DiscoPublicKey()
if err != nil {
return nil, err
}
addrs := node.IPAddresses.Prefixes()
allowedIPs := append(
@ -112,6 +97,11 @@ func tailNode(
tags, _ := pol.TagsOfNode(node)
tags = lo.Uniq(append(tags, node.ForcedTags...))
endpoints, err := node.EndpointsToAddrPort()
if err != nil {
return nil, err
}
tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
@ -121,14 +111,14 @@ func tailNode(
User: tailcfg.UserID(node.UserID),
Key: nodeKey,
Key: node.NodeKey,
KeyExpiry: keyExpiry,
Machine: machineKey,
DiscoKey: discoKey,
Machine: node.MachineKey,
DiscoKey: node.DiscoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
Endpoints: node.Endpoints,
Endpoints: endpoints,
DERP: derp,
Hostinfo: hostInfo.View(),
Created: node.CreatedAt,

View file

@ -58,16 +58,36 @@ func TestTailNode(t *testing.T) {
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: nil,
wantErr: true,
want: &tailcfg.Node{
StableID: "0",
Addresses: []netip.Prefix{},
AllowedIPs: []netip.Prefix{},
DERP: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
PrimaryRoutes: []netip.Prefix{},
Online: new(bool),
MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{
"https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin",
"https://tailscale.com/cap/ssh", "debug-disable-upnp",
},
},
wantErr: false,
},
{
name: "minimal-node",
node: &types.Node{
ID: 0,
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
ID: 0,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPAddresses: []netip.Addr{
netip.MustParseAddr("100.64.0.1"),
},
@ -133,10 +153,9 @@ func TestTailNode(t *testing.T) {
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("192.168.0.0/24"),
},
Endpoints: []string{},
DERP: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
DERP: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
Tags: []string{},

View file

@ -6,6 +6,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"tailscale.com/types/key"
)
type Notifier struct {
@ -17,9 +18,9 @@ func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to add node")
defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to add node")
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
n.l.Lock()
defer n.l.Unlock()
@ -28,17 +29,17 @@ func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
n.nodes = make(map[string]chan<- types.StateUpdate)
}
n.nodes[machineKey] = c
n.nodes[machineKey.String()] = c
log.Trace().
Str("machine_key", machineKey).
Str("machine_key", machineKey.ShortString()).
Int("open_chans", len(n.nodes)).
Msg("Added new channel")
}
func (n *Notifier) RemoveNode(machineKey string) {
log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to remove node")
defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to remove node")
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
n.l.Lock()
defer n.l.Unlock()
@ -47,10 +48,10 @@ func (n *Notifier) RemoveNode(machineKey string) {
return
}
delete(n.nodes, machineKey)
delete(n.nodes, machineKey.String())
log.Trace().
Str("machine_key", machineKey).
Str("machine_key", machineKey.ShortString()).
Int("open_chans", len(n.nodes)).
Msg("Removed channel")
}

View file

@ -90,42 +90,28 @@ func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.T
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
machineKeyStr, ok := vars["mkey"]
log.Debug().
Caller().
Str("node_key", nodeKeyStr).
Str("machine_key", machineKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(nodeKeyStr),
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)
if !ok || nodeKeyStr == "" || err != nil {
if err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
@ -154,7 +140,7 @@ func (h *Headscale) RegisterOIDC(
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(
stateStr,
nodeKey,
machineKey,
registerCacheExpiration,
)
@ -232,7 +218,7 @@ func (h *Headscale) OIDCCallback(
return
}
nodeKey, nodeExists, err := h.validateNodeForOIDCCallback(
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
writer,
state,
claims,
@ -255,7 +241,7 @@ func (h *Headscale) OIDCCallback(
return
}
if err := h.registerNodeForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil {
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
return
}
@ -462,10 +448,10 @@ func (h *Headscale) validateNodeForOIDCCallback(
state string,
claims *IDTokenClaims,
expiry time.Time,
) (*key.NodePublic, bool, error) {
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !nodeKeyFound {
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -478,11 +464,11 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, false, errOIDCNodeKeyMissing
}
var nodeKey key.NodePublic
nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic)
if !nodeKeyOK {
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", nodeKeyIf).
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
@ -498,7 +484,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByNodeKey(nodeKey)
node, _ := h.db.GetNodeByMachineKey(machineKey)
if node != nil {
log.Trace().
@ -553,7 +539,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, true, nil
}
return &nodeKey, false, nil
return &machineKey, false, nil
}
func getUserName(
@ -624,13 +610,13 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter,
user *types.User,
nodeKey *key.NodePublic,
machineKey *key.MachinePublic,
expiry time.Time,
) error {
if _, err := h.db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
h.registrationCache,
nodeKey.String(),
*machineKey,
user.Name,
&expiry,
util.RegisterMethodOIDC,

View file

@ -14,12 +14,29 @@ import (
"go4.org/netipx"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
var ipComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
var mkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String()
})
var nkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool {
return x.String() == y.String()
})
var dkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String()
})
var keyComparers []cmp.Option = []cmp.Option{
mkeyComparer, nkeyComparer, dkeyComparer,
}
func Test(t *testing.T) {
check.TestingT(t)
}
@ -951,7 +968,7 @@ func Test_listNodesInUser(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user)
if diff := cmp.Diff(test.want, got); diff != "" {
if diff := cmp.Diff(test.want, got, keyComparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
}
})
@ -1704,7 +1721,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.nodes,
test.args.user,
)
if diff := cmp.Diff(test.want, got, ipComparer); diff != "" {
if diff := cmp.Diff(test.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff)
}
})
@ -2723,7 +2740,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
tt.args.nodes,
tt.args.rules,
)
if diff := cmp.Diff(tt.want, got, ipComparer); diff != "" {
if diff := cmp.Diff(tt.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
}
})
@ -2986,9 +3003,6 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
node := &types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnodes",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 0,
@ -3041,9 +3055,6 @@ func TestInvalidTagValidUser(t *testing.T) {
node := &types.Node{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnodes",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
@ -3095,9 +3106,6 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
node := &types.Node{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnodes",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
@ -3159,9 +3167,6 @@ func TestValidTagInvalidUser(t *testing.T) {
node := &types.Node{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
@ -3179,9 +3184,6 @@ func TestValidTagInvalidUser(t *testing.T) {
nodes2 := &types.Node{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: 1,

View file

@ -34,7 +34,7 @@ func logPollFunc(
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey).
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
Msg(msg)
},
@ -45,7 +45,7 @@ func logPollFunc(
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey).
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
Err(err).
Msg(msg)
@ -81,7 +81,7 @@ func (h *Headscale) handlePoll(
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey).
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
Strs("endpoints", node.Endpoints).
Msg("Received endpoint update")
@ -90,8 +90,8 @@ func (h *Headscale) handlePoll(
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.DiscoKey = mapRequest.DiscoKey.String()
node.Endpoints = mapRequest.Endpoints
node.DiscoKey = mapRequest.DiscoKey
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
@ -113,7 +113,7 @@ func (h *Headscale) handlePoll(
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
},
node.MachineKey)
node.MachineKey.String())
writer.WriteHeader(http.StatusOK)
if f, ok := writer.(http.Flusher); ok {
@ -143,8 +143,8 @@ func (h *Headscale) handlePoll(
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.DiscoKey = mapRequest.DiscoKey.String()
node.Endpoints = mapRequest.Endpoints
node.DiscoKey = mapRequest.DiscoKey
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
// When a node connects to control, list the peers it has at
// that given point, further updates are kept in memory in
@ -222,7 +222,7 @@ func (h *Headscale) handlePoll(
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
},
node.MachineKey)
node.MachineKey.String())
// Set up the client stream
h.pollNetMapStreamWG.Add(1)
@ -342,7 +342,7 @@ func (h *Headscale) handlePoll(
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey).
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
TimeDiff("timeSpent", time.Now(), now).
Msg("update sent")

View file

@ -13,6 +13,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -24,10 +25,30 @@ var (
// Node is a Headscale client.
type Node struct {
ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"`
NodeKey string
DiscoKey string
ID uint64 `gorm:"primary_key"`
// MachineKeyValue is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use MachineKey instead.
MachineKeyValue string `gorm:"column:machine_key;unique_index"`
// NodeKeyValue is the string representation of NodeKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use NodeKey instead.
NodeKeyValue string `gorm:"column:node_key"`
// DiscoKeyValue is the string representation of DiscoKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use DiscoKey instead.
DiscoKeyValue string `gorm:"column:disco_key"`
MachineKey key.MachinePublic `gorm:"-"`
NodeKey key.NodePublic `gorm:"-"`
DiscoKey key.DiscoPublic `gorm:"-"`
IPAddresses NodeAddresses
// Hostname represents the name given by the Tailscale
@ -174,6 +195,31 @@ func (node Node) IsExpired() bool {
return time.Now().UTC().After(*node.Expiry)
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) {
var ret []netip.AddrPort
for _, ep := range node.Endpoints {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return nil, err
}
ret = append(ret, addrPort)
}
return ret, nil
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) {
var strs StringList
for _, addrPort := range in {
strs = append(strs, addrPort.String())
}
node.Endpoints = strs
}
// IsOnline returns if the node is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
@ -226,13 +272,52 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
return found
}
// BeforeSave is a hook that ensures that some values that
// cannot be directly marshalled into database values are stored
// correctly in the database.
// This currently means storing the keys as strings.
func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
n.MachineKeyValue = n.MachineKey.String()
n.NodeKeyValue = n.NodeKey.String()
n.DiscoKeyValue = n.DiscoKey.String()
return
}
// AfterFind is a hook that ensures that Node objects fields that
// has a different type in the database is unwrapped and populated
// correctly.
// This currently unmarshals all the keys, stored as strings, into
// the proper types.
func (n *Node) AfterFind(tx *gorm.DB) (err error) {
var machineKey key.MachinePublic
if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil {
return err
}
n.MachineKey = machineKey
var nodeKey key.NodePublic
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil {
return err
}
n.NodeKey = nodeKey
var discoKey key.DiscoPublic
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil {
return err
}
n.DiscoKey = discoKey
return
}
func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{
Id: node.ID,
MachineKey: node.MachineKey,
MachineKey: node.MachineKey.String(),
NodeKey: node.NodeKey,
DiscoKey: node.DiscoKey,
NodeKey: node.NodeKey.String(),
DiscoKey: node.DiscoKey.String(),
IpAddresses: node.IPAddresses.StringSlice(),
Name: node.Hostname,
GivenName: node.GivenName,
@ -289,47 +374,6 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
return hostname, nil
}
func (node *Node) MachinePublicKey() (key.MachinePublic, error) {
var machineKey key.MachinePublic
if node.MachineKey != "" {
err := machineKey.UnmarshalText(
[]byte(node.MachineKey),
)
if err != nil {
return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err)
}
}
return machineKey, nil
}
func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
var discoKey key.DiscoPublic
if node.DiscoKey != "" {
err := discoKey.UnmarshalText(
[]byte(node.DiscoKey),
)
if err != nil {
return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err)
}
} else {
discoKey = key.DiscoPublic{}
}
return discoKey, nil
}
func (node *Node) NodePublicKey() (key.NodePublic, error) {
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(node.NodeKey))
if err != nil {
return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err)
}
return nodeKey, nil
}
func (node Node) String() string {
return node.Hostname
}