Use tailscale key types instead of strings (#1609)
* upgrade tailscale Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * make Node object use actualy tailscale key types This commit changes the Node struct to have both a field for strings to store the keys in the database and a dedicated Key for each type of key. The keys are populated and stored with Gorm hooks to ensure the data is stored in the db. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use key types throughout the code Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * make sure machinekey is concistently used Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use machine key in auth url Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * fix web register Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use key type in notifier Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * fix relogin with webauth Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
c0fd06e3f5
commit
ed4e19996b
22 changed files with 550 additions and 471 deletions
|
@ -449,10 +449,10 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
|
||||
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
|
||||
h.addLegacyHandlers(router)
|
||||
|
||||
router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
|
||||
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
|
||||
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
|
||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||
|
|
|
@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
|
|||
// is that the client will hammer headscale with requests until it gets a
|
||||
// successful RegisterResponse.
|
||||
if registerRequest.Followup != "" {
|
||||
if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
|
||||
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("node", registerRequest.Hostinfo.Hostname).
|
||||
|
@ -78,7 +78,7 @@ func (h *Headscale) handleRegister(
|
|||
Msg("New node not yet in the database")
|
||||
|
||||
givenName, err := h.db.GenerateGivenName(
|
||||
machineKey.String(),
|
||||
machineKey,
|
||||
registerRequest.Hostinfo.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
|
|||
// We create the node and then keep it around until a callback
|
||||
// happens
|
||||
newNode := types.Node{
|
||||
MachineKey: machineKey.String(),
|
||||
MachineKey: machineKey,
|
||||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
NodeKey: registerRequest.NodeKey.String(),
|
||||
NodeKey: registerRequest.NodeKey,
|
||||
LastSeen: &now,
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func (h *Headscale) handleRegister(
|
|||
}
|
||||
|
||||
h.registrationCache.Set(
|
||||
newNode.NodeKey,
|
||||
machineKey.String(),
|
||||
newNode,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
@ -134,11 +134,7 @@ func (h *Headscale) handleRegister(
|
|||
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
|
||||
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
|
||||
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
|
||||
var storedMachineKey key.MachinePublic
|
||||
err = storedMachineKey.UnmarshalText(
|
||||
[]byte(node.MachineKey),
|
||||
)
|
||||
if err != nil || storedMachineKey.IsZero() {
|
||||
if err != nil || node.MachineKey.IsZero() {
|
||||
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -156,7 +152,7 @@ func (h *Headscale) handleRegister(
|
|||
// - Trying to log out (sending a expiry in the past)
|
||||
// - A valid, registered node, looking for /map
|
||||
// - Expired node wanting to reauthenticate
|
||||
if node.NodeKey == registerRequest.NodeKey.String() {
|
||||
if node.NodeKey.String() == registerRequest.NodeKey.String() {
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !registerRequest.Expiry.IsZero() &&
|
||||
|
@ -176,7 +172,7 @@ func (h *Headscale) handleRegister(
|
|||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if node.NodeKey == registerRequest.OldNodeKey.String() &&
|
||||
if node.NodeKey.String() == registerRequest.OldNodeKey.String() &&
|
||||
!node.IsExpired() {
|
||||
h.handleNodeKeyRefresh(
|
||||
writer,
|
||||
|
@ -207,9 +203,9 @@ func (h *Headscale) handleRegister(
|
|||
// we need to make sure the NodeKey matches the one in the request
|
||||
// TODO(juan): What happens when using fast user switching between two
|
||||
// headscale-managed tailnets?
|
||||
node.NodeKey = registerRequest.NodeKey.String()
|
||||
node.NodeKey = registerRequest.NodeKey
|
||||
h.registrationCache.Set(
|
||||
registerRequest.NodeKey.String(),
|
||||
machineKey.String(),
|
||||
*node,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
@ -294,7 +290,7 @@ func (h *Headscale) handleAuthKey(
|
|||
Str("node", registerRequest.Hostinfo.Hostname).
|
||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||
|
||||
nodeKey := registerRequest.NodeKey.String()
|
||||
nodeKey := registerRequest.NodeKey
|
||||
|
||||
// retrieve node information if it exist
|
||||
// The error is not important, because if it does not
|
||||
|
@ -342,7 +338,7 @@ func (h *Headscale) handleAuthKey(
|
|||
} else {
|
||||
now := time.Now().UTC()
|
||||
|
||||
givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
|
||||
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -359,7 +355,7 @@ func (h *Headscale) handleAuthKey(
|
|||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
UserID: pak.User.ID,
|
||||
MachineKey: machineKey.String(),
|
||||
MachineKey: machineKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Expiry: ®isterRequest.Expiry,
|
||||
NodeKey: nodeKey,
|
||||
|
@ -460,12 +456,12 @@ func (h *Headscale) handleNewNode(
|
|||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
registerRequest.NodeKey,
|
||||
machineKey.String(),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
registerRequest.NodeKey)
|
||||
machineKey.String())
|
||||
}
|
||||
|
||||
respBody, err := mapper.MarshalResponse(resp, isNoise, h.privateKey2019, machineKey)
|
||||
|
@ -715,11 +711,11 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
|
|||
if h.oauth2Config != nil {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
registerRequest.NodeKey)
|
||||
machineKey.String())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
registerRequest.NodeKey)
|
||||
machineKey.String())
|
||||
}
|
||||
|
||||
respBody, err := mapper.MarshalResponse(resp, isNoise, h.privateKey2019, machineKey)
|
||||
|
|
|
@ -2,6 +2,7 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
@ -99,7 +100,7 @@ func NewHeadscaleDatabase(
|
|||
// node was registered.
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
|
||||
|
||||
// If the MacNodehine table has a column for registered,
|
||||
// If the Node table has a column for registered,
|
||||
// find all occourences of "false" and drop them. Then
|
||||
// remove the column.
|
||||
if dbConn.Migrator().HasColumn(&types.Node{}, "registered") {
|
||||
|
@ -114,13 +115,13 @@ func NewHeadscaleDatabase(
|
|||
for _, node := range nodes {
|
||||
log.Info().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Msg("Deleting unregistered node")
|
||||
if err := dbConn.Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Msg("Error deleting unregistered node")
|
||||
}
|
||||
}
|
||||
|
@ -136,6 +137,50 @@ func NewHeadscaleDatabase(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure all keys have correct prefixes
|
||||
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
|
||||
type result struct {
|
||||
ID uint64
|
||||
MachineKey string
|
||||
NodeKey string
|
||||
DiscoKey string
|
||||
}
|
||||
var results []result
|
||||
err = db.db.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range results {
|
||||
mKey := node.MachineKey
|
||||
if !strings.HasPrefix(node.MachineKey, "mkey:") {
|
||||
mKey = "mkey:" + node.MachineKey
|
||||
}
|
||||
nKey := node.NodeKey
|
||||
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
|
||||
nKey = "nodekey:" + node.NodeKey
|
||||
}
|
||||
|
||||
dKey := node.DiscoKey
|
||||
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
|
||||
dKey = "discokey:" + node.DiscoKey
|
||||
}
|
||||
|
||||
err := db.db.Exec("UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
|
||||
sql.Named("mKey", mKey),
|
||||
sql.Named("nKey", nKey),
|
||||
sql.Named("dKey", dKey),
|
||||
sql.Named("id", node.ID)).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
|
||||
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
|
||||
|
||||
|
@ -195,11 +240,6 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&types.Node{}, "given_name") {
|
||||
nodes := types.Nodes{}
|
||||
if err := dbConn.Find(&nodes).Error; err != nil {
|
||||
|
@ -253,27 +293,6 @@ func NewHeadscaleDatabase(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure all keys have correct prefixes
|
||||
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
|
||||
nodes := types.Nodes{}
|
||||
if err := dbConn.Find(&nodes).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
|
||||
node.DiscoKey = "discokey:" + node.DiscoKey
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
|
||||
node.NodeKey = "nodekey:" + node.NodeKey
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(node.MachineKey, "mkey:") {
|
||||
node.MachineKey = "mkey:" + node.MachineKey
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): is this needed?
|
||||
err = db.setValue("db_version", dbVersion)
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
|
|||
Preload("User").
|
||||
Preload("Routes").
|
||||
Where("node_key <> ?",
|
||||
node.NodeKey).Find(&nodes).Error; err != nil {
|
||||
node.NodeKey.String()).Find(&nodes).Error; err != nil {
|
||||
return types.Nodes{}, err
|
||||
}
|
||||
|
||||
|
@ -268,7 +268,7 @@ func (hsdb *HSDatabase) SetTags(
|
|||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
}, node.MachineKey)
|
||||
}, node.MachineKey.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -304,7 +304,7 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
|
|||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
}, node.MachineKey)
|
||||
}, node.MachineKey.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -330,7 +330,7 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error
|
|||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
}, node.MachineKey)
|
||||
}, node.MachineKey.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -376,7 +376,7 @@ func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
|
|||
|
||||
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
||||
cache *cache.Cache,
|
||||
nodeKeyStr string,
|
||||
mkey key.MachinePublic,
|
||||
userName string,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
|
@ -384,20 +384,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
|||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
nodeKey := key.NodePublic{}
|
||||
err := nodeKey.UnmarshalText([]byte(nodeKeyStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("nodeKey", nodeKey.ShortString()).
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
Str("userName", userName).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||
Msg("Registering node from API/CLI or auth callback")
|
||||
|
||||
if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
|
||||
if nodeInterface, ok := cache.Get(mkey.String()); ok {
|
||||
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
||||
user, err := hsdb.getUser(userName)
|
||||
if err != nil {
|
||||
|
@ -425,7 +419,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
|||
)
|
||||
|
||||
if err == nil {
|
||||
cache.Delete(nodeKeyStr)
|
||||
cache.Delete(mkey.String())
|
||||
}
|
||||
|
||||
return node, err
|
||||
|
@ -448,8 +442,8 @@ func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
|
|||
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Name).
|
||||
Msg("Registering node")
|
||||
|
||||
|
@ -464,8 +458,8 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
|||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Name).
|
||||
Msg("Node authorized again")
|
||||
|
||||
|
@ -507,7 +501,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
|
|||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
NodeKey: nodeKey.String(),
|
||||
NodeKey: nodeKey,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -524,7 +518,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
|
|||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
MachineKey: machineKey.String(),
|
||||
MachineKey: machineKey,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -703,7 +697,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
|
|||
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
}, node.MachineKey)
|
||||
}, node.MachineKey.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -734,7 +728,7 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
|||
return normalizedHostname, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
||||
func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
|
@ -749,17 +743,22 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string
|
|||
return "", err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.MachineKey != machineKey && node.GivenName == givenName {
|
||||
postfixedName, err := generateGivenName(suppliedName, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
givenName = postfixedName
|
||||
var nodeFound *types.Node
|
||||
for idx, node := range nodes {
|
||||
if node.GivenName == givenName {
|
||||
nodeFound = nodes[idx]
|
||||
}
|
||||
}
|
||||
|
||||
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
|
||||
postfixedName, err := generateGivenName(suppliedName, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
givenName = postfixedName
|
||||
}
|
||||
|
||||
return givenName, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -25,11 +25,13 @@ func (s *Suite) TestGetNode(c *check.C) {
|
|||
_, err = db.GetNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -51,11 +53,13 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
|||
_, err = db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -82,9 +86,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
|
|||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public().String(),
|
||||
NodeKey: nodeKey.Public().String(),
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -113,9 +116,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
|||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public().String(),
|
||||
NodeKey: nodeKey.Public().String(),
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -130,11 +132,14 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
|||
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -160,11 +165,13 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -205,11 +212,13 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
|
||||
},
|
||||
|
@ -288,11 +297,13 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
|||
_, err = db.GetNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -345,11 +356,15 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
_, err = db.GetNode("user-1", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machineKey2 := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "node-key-1",
|
||||
NodeKey: "node-key-1",
|
||||
DiscoKey: "disco-key-1",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "hostname-1",
|
||||
GivenName: "hostname-1",
|
||||
UserID: user1.ID,
|
||||
|
@ -358,25 +373,20 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
}
|
||||
db.db.Save(node)
|
||||
|
||||
givenName, err := db.GenerateGivenName("node-key-2", "hostname-2")
|
||||
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
|
||||
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("node-key-1", "hostname-1")
|
||||
givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
|
||||
comment = check.Commentf("Same user, same node, same hostname, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("node-key-2", "hostname-1")
|
||||
givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
|
||||
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("node-key-2", "hostname-1")
|
||||
comment = check.Commentf("Unique users, unique nodes, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
|
@ -389,11 +399,13 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
_, err = db.GetNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -565,6 +577,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
defaultRouteV6 := netip.MustParsePrefix("::/0")
|
||||
|
@ -574,9 +587,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: nodeKey.Public().String(),
|
||||
DiscoKey: "faa",
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "test",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -27,19 +28,22 @@ func (s *Suite) SetUpTest(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TearDownTest(c *check.C) {
|
||||
os.RemoveAll(tmpDir)
|
||||
// os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
func (s *Suite) ResetDB(c *check.C) {
|
||||
if len(tmpDir) != 0 {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
// if len(tmpDir) != 0 {
|
||||
// os.RemoveAll(tmpDir)
|
||||
// }
|
||||
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
||||
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
|
|
|
@ -172,12 +172,18 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
) (*v1.RegisterNodeResponse, error) {
|
||||
log.Trace().
|
||||
Str("user", request.GetUser()).
|
||||
Str("node_key", request.GetKey()).
|
||||
Str("machine_key", request.GetKey()).
|
||||
Msg("Registering node")
|
||||
|
||||
var mkey key.MachinePublic
|
||||
err := mkey.UnmarshalText([]byte(request.GetKey()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node, err := api.h.db.RegisterNodeFromAuthCallback(
|
||||
api.h.registrationCache,
|
||||
request.GetKey(),
|
||||
mkey,
|
||||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
|
@ -521,13 +527,22 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||
Hostname: "DebugTestNode",
|
||||
}
|
||||
|
||||
givenName, err := api.h.db.GenerateGivenName(request.GetKey(), request.GetName())
|
||||
var mkey key.MachinePublic
|
||||
err = mkey.UnmarshalText([]byte(request.GetKey()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
newNode := types.Node{
|
||||
MachineKey: request.GetKey(),
|
||||
MachineKey: mkey,
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: request.GetName(),
|
||||
GivenName: givenName,
|
||||
User: *user,
|
||||
|
@ -538,14 +553,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||
HostInfo: types.HostInfo(hostinfo),
|
||||
}
|
||||
|
||||
nodeKey := key.NodePublic{}
|
||||
err = nodeKey.UnmarshalText([]byte(request.GetKey()))
|
||||
if err != nil {
|
||||
log.Panic().Msg("can not add node for debug. invalid node key")
|
||||
}
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
Msg("adding debug machine via CLI, appending to registration cache")
|
||||
|
||||
api.h.registrationCache.Set(
|
||||
nodeKey.String(),
|
||||
mkey.String(),
|
||||
newNode,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
@ -207,33 +206,16 @@ func (h *Headscale) RegisterWebAPI(
|
|||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
nodeKeyStr, ok := vars["nkey"]
|
||||
|
||||
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
machineKeyStr := vars["mkey"]
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(nodeKeyStr),
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText(
|
||||
[]byte(machineKeyStr),
|
||||
)
|
||||
|
||||
if !ok || nodeKeyStr == "" || err != nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -251,7 +233,7 @@ func (h *Headscale) RegisterWebAPI(
|
|||
|
||||
var content bytes.Buffer
|
||||
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
|
||||
Key: nodeKeyStr,
|
||||
Key: machineKey.String(),
|
||||
}); err != nil {
|
||||
log.Error().
|
||||
Str("func", "RegisterWebAPI").
|
||||
|
|
|
@ -368,17 +368,6 @@ func (m *Mapper) marshalMapResponse(
|
|||
) ([]byte, error) {
|
||||
atomic.AddUint64(&m.seq, 1)
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText([]byte(node.MachineKey))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -426,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
|
|||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
if !m.isNoise { // if legacy protocol
|
||||
respBody = m.privateKey2019.SealTo(machineKey, respBody)
|
||||
respBody = m.privateKey2019.SealTo(node.MachineKey, respBody)
|
||||
}
|
||||
} else {
|
||||
if !m.isNoise { // if legacy protocol
|
||||
respBody = m.privateKey2019.SealTo(machineKey, jsonBody)
|
||||
respBody = m.privateKey2019.SealTo(node.MachineKey, jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
|
|
|
@ -166,10 +166,16 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
||||
|
||||
mini := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
ID: 0,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
NodeKey: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
|
@ -226,7 +232,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
Endpoints: []string{},
|
||||
DERP: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
|
@ -244,10 +249,16 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
peer1 := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
ID: 1,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
NodeKey: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
Hostname: "peer1",
|
||||
GivenName: "peer1",
|
||||
|
@ -278,7 +289,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
),
|
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Endpoints: []string{},
|
||||
DERP: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
|
@ -296,10 +306,16 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
peer2 := &types.Node{
|
||||
ID: 2,
|
||||
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
ID: 2,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
NodeKey: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
Hostname: "peer2",
|
||||
GivenName: "peer2",
|
||||
|
|
|
@ -52,21 +52,6 @@ func tailNode(
|
|||
baseDomain string,
|
||||
randomClientPort bool,
|
||||
) (*tailcfg.Node, error) {
|
||||
nodeKey, err := node.NodePublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machineKey, err := node.MachinePublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
discoKey, err := node.DiscoPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrs := node.IPAddresses.Prefixes()
|
||||
|
||||
allowedIPs := append(
|
||||
|
@ -112,6 +97,11 @@ func tailNode(
|
|||
tags, _ := pol.TagsOfNode(node)
|
||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
|
||||
endpoints, err := node.EndpointsToAddrPort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tNode := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(node.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
|
@ -121,14 +111,14 @@ func tailNode(
|
|||
|
||||
User: tailcfg.UserID(node.UserID),
|
||||
|
||||
Key: nodeKey,
|
||||
Key: node.NodeKey,
|
||||
KeyExpiry: keyExpiry,
|
||||
|
||||
Machine: machineKey,
|
||||
DiscoKey: discoKey,
|
||||
Machine: node.MachineKey,
|
||||
DiscoKey: node.DiscoKey,
|
||||
Addresses: addrs,
|
||||
AllowedIPs: allowedIPs,
|
||||
Endpoints: node.Endpoints,
|
||||
Endpoints: endpoints,
|
||||
DERP: derp,
|
||||
Hostinfo: hostInfo.View(),
|
||||
Created: node.CreatedAt,
|
||||
|
|
|
@ -58,16 +58,36 @@ func TestTailNode(t *testing.T) {
|
|||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
want: &tailcfg.Node{
|
||||
StableID: "0",
|
||||
Addresses: []netip.Prefix{},
|
||||
AllowedIPs: []netip.Prefix{},
|
||||
DERP: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Tags: []string{},
|
||||
PrimaryRoutes: []netip.Prefix{},
|
||||
Online: new(bool),
|
||||
MachineAuthorized: true,
|
||||
Capabilities: []tailcfg.NodeCapability{
|
||||
"https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin",
|
||||
"https://tailscale.com/cap/ssh", "debug-disable-upnp",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "minimal-node",
|
||||
node: &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
DiscoKey: "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
ID: 0,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
NodeKey: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{
|
||||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
|
@ -133,10 +153,9 @@ func TestTailNode(t *testing.T) {
|
|||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
Endpoints: []string{},
|
||||
DERP: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
DERP: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
|
||||
Tags: []string{},
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
|
@ -17,9 +18,9 @@ func NewNotifier() *Notifier {
|
|||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
|
||||
log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to add node")
|
||||
defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to add node")
|
||||
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
|
||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
|
||||
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
|
||||
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
@ -28,17 +29,17 @@ func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
|
|||
n.nodes = make(map[string]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
n.nodes[machineKey] = c
|
||||
n.nodes[machineKey.String()] = c
|
||||
|
||||
log.Trace().
|
||||
Str("machine_key", machineKey).
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
Int("open_chans", len(n.nodes)).
|
||||
Msg("Added new channel")
|
||||
}
|
||||
|
||||
func (n *Notifier) RemoveNode(machineKey string) {
|
||||
log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to remove node")
|
||||
defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to remove node")
|
||||
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
|
||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
|
||||
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
|
||||
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
@ -47,10 +48,10 @@ func (n *Notifier) RemoveNode(machineKey string) {
|
|||
return
|
||||
}
|
||||
|
||||
delete(n.nodes, machineKey)
|
||||
delete(n.nodes, machineKey.String())
|
||||
|
||||
log.Trace().
|
||||
Str("machine_key", machineKey).
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
Int("open_chans", len(n.nodes)).
|
||||
Msg("Removed channel")
|
||||
}
|
||||
|
|
|
@ -90,42 +90,28 @@ func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.T
|
|||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:nKey.
|
||||
// Listens in /oidc/register/:mKey.
|
||||
func (h *Headscale) RegisterOIDC(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
nodeKeyStr, ok := vars["nkey"]
|
||||
machineKeyStr, ok := vars["mkey"]
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("node_key", nodeKeyStr).
|
||||
Str("machine_key", machineKeyStr).
|
||||
Bool("ok", ok).
|
||||
Msg("Received oidc register call")
|
||||
|
||||
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(nodeKeyStr),
|
||||
var machineKey key.MachinePublic
|
||||
err := machineKey.UnmarshalText(
|
||||
[]byte(machineKeyStr),
|
||||
)
|
||||
|
||||
if !ok || nodeKeyStr == "" || err != nil {
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Msg("Failed to parse incoming nodekey in OIDC registration")
|
||||
|
@ -154,7 +140,7 @@ func (h *Headscale) RegisterOIDC(
|
|||
// place the node key into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(
|
||||
stateStr,
|
||||
nodeKey,
|
||||
machineKey,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
|
@ -232,7 +218,7 @@ func (h *Headscale) OIDCCallback(
|
|||
return
|
||||
}
|
||||
|
||||
nodeKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
||||
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
||||
writer,
|
||||
state,
|
||||
claims,
|
||||
|
@ -255,7 +241,7 @@ func (h *Headscale) OIDCCallback(
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.registerNodeForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil {
|
||||
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -462,10 +448,10 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
state string,
|
||||
claims *IDTokenClaims,
|
||||
expiry time.Time,
|
||||
) (*key.NodePublic, bool, error) {
|
||||
) (*key.MachinePublic, bool, error) {
|
||||
// retrieve nodekey from state cache
|
||||
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
|
||||
if !nodeKeyFound {
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
if !machineKeyFound {
|
||||
log.Trace().
|
||||
Msg("requested node state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -478,11 +464,11 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic)
|
||||
if !nodeKeyOK {
|
||||
var machineKey key.MachinePublic
|
||||
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
|
||||
if !machineKeyOK {
|
||||
log.Trace().
|
||||
Interface("got", nodeKeyIf).
|
||||
Interface("got", machineKeyIf).
|
||||
Msg("requested node state key is not a nodekey")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
|
@ -498,7 +484,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := h.db.GetNodeByNodeKey(nodeKey)
|
||||
node, _ := h.db.GetNodeByMachineKey(machineKey)
|
||||
|
||||
if node != nil {
|
||||
log.Trace().
|
||||
|
@ -553,7 +539,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
return nil, true, nil
|
||||
}
|
||||
|
||||
return &nodeKey, false, nil
|
||||
return &machineKey, false, nil
|
||||
}
|
||||
|
||||
func getUserName(
|
||||
|
@ -624,13 +610,13 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
|||
func (h *Headscale) registerNodeForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
user *types.User,
|
||||
nodeKey *key.NodePublic,
|
||||
machineKey *key.MachinePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
if _, err := h.db.RegisterNodeFromAuthCallback(
|
||||
// TODO(kradalby): find a better way to use the cache across modules
|
||||
h.registrationCache,
|
||||
nodeKey.String(),
|
||||
*machineKey,
|
||||
user.Name,
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
|
|
|
@ -14,12 +14,29 @@ import (
|
|||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
var ipComparer = cmp.Comparer(func(x, y netip.Addr) bool {
|
||||
return x.Compare(y) == 0
|
||||
})
|
||||
|
||||
var mkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
|
||||
return x.String() == y.String()
|
||||
})
|
||||
|
||||
var nkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool {
|
||||
return x.String() == y.String()
|
||||
})
|
||||
|
||||
var dkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
|
||||
return x.String() == y.String()
|
||||
})
|
||||
|
||||
var keyComparers []cmp.Option = []cmp.Option{
|
||||
mkeyComparer, nkeyComparer, dkeyComparer,
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
check.TestingT(t)
|
||||
}
|
||||
|
@ -951,7 +968,7 @@ func Test_listNodesInUser(t *testing.T) {
|
|||
t.Run(test.name, func(t *testing.T) {
|
||||
got := filterNodesByUser(test.args.nodes, test.args.user)
|
||||
|
||||
if diff := cmp.Diff(test.want, got); diff != "" {
|
||||
if diff := cmp.Diff(test.want, got, keyComparers...); diff != "" {
|
||||
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
@ -1704,7 +1721,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
|||
test.args.nodes,
|
||||
test.args.user,
|
||||
)
|
||||
if diff := cmp.Diff(test.want, got, ipComparer); diff != "" {
|
||||
if diff := cmp.Diff(test.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
|
||||
t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
@ -2723,7 +2740,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
|||
tt.args.nodes,
|
||||
tt.args.rules,
|
||||
)
|
||||
if diff := cmp.Diff(tt.want, got, ipComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
|
||||
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
@ -2986,9 +3003,6 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
|||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testnodes",
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 0,
|
||||
|
@ -3041,9 +3055,6 @@ func TestInvalidTagValidUser(t *testing.T) {
|
|||
|
||||
node := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testnodes",
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
|
@ -3095,9 +3106,6 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
|||
|
||||
node := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testnodes",
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
|
@ -3159,9 +3167,6 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
|
||||
node := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "webserver",
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
|
@ -3179,9 +3184,6 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
|
||||
nodes2 := &types.Node{
|
||||
ID: 2,
|
||||
MachineKey: "56789",
|
||||
NodeKey: "bar2",
|
||||
DiscoKey: "faab",
|
||||
Hostname: "user",
|
||||
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||
UserID: 1,
|
||||
|
|
|
@ -34,7 +34,7 @@ func logPollFunc(
|
|||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("node", node.Hostname).
|
||||
Msg(msg)
|
||||
},
|
||||
|
@ -45,7 +45,7 @@ func logPollFunc(
|
|||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("node", node.Hostname).
|
||||
Err(err).
|
||||
Msg(msg)
|
||||
|
@ -81,7 +81,7 @@ func (h *Headscale) handlePoll(
|
|||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("node", node.Hostname).
|
||||
Strs("endpoints", node.Endpoints).
|
||||
Msg("Received endpoint update")
|
||||
|
@ -90,8 +90,8 @@ func (h *Headscale) handlePoll(
|
|||
node.LastSeen = &now
|
||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
node.DiscoKey = mapRequest.DiscoKey.String()
|
||||
node.Endpoints = mapRequest.Endpoints
|
||||
node.DiscoKey = mapRequest.DiscoKey
|
||||
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
|
||||
|
||||
if err := h.db.NodeSave(node); err != nil {
|
||||
logErr(err, "Failed to persist/update node in the database")
|
||||
|
@ -113,7 +113,7 @@ func (h *Headscale) handlePoll(
|
|||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
},
|
||||
node.MachineKey)
|
||||
node.MachineKey.String())
|
||||
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
if f, ok := writer.(http.Flusher); ok {
|
||||
|
@ -143,8 +143,8 @@ func (h *Headscale) handlePoll(
|
|||
node.LastSeen = &now
|
||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
node.DiscoKey = mapRequest.DiscoKey.String()
|
||||
node.Endpoints = mapRequest.Endpoints
|
||||
node.DiscoKey = mapRequest.DiscoKey
|
||||
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
|
||||
|
||||
// When a node connects to control, list the peers it has at
|
||||
// that given point, further updates are kept in memory in
|
||||
|
@ -222,7 +222,7 @@ func (h *Headscale) handlePoll(
|
|||
Type: types.StatePeerChanged,
|
||||
Changed: types.Nodes{node},
|
||||
},
|
||||
node.MachineKey)
|
||||
node.MachineKey.String())
|
||||
|
||||
// Set up the client stream
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
|
@ -342,7 +342,7 @@ func (h *Headscale) handlePoll(
|
|||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", node.NodeKey).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("node", node.Hostname).
|
||||
TimeDiff("timeSpent", time.Now(), now).
|
||||
Msg("update sent")
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"go4.org/netipx"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
@ -24,10 +25,30 @@ var (
|
|||
|
||||
// Node is a Headscale client.
|
||||
type Node struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
||||
NodeKey string
|
||||
DiscoKey string
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
|
||||
// MachineKeyValue is the string representation of MachineKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use MachineKey instead.
|
||||
MachineKeyValue string `gorm:"column:machine_key;unique_index"`
|
||||
|
||||
// NodeKeyValue is the string representation of NodeKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use NodeKey instead.
|
||||
NodeKeyValue string `gorm:"column:node_key"`
|
||||
|
||||
// DiscoKeyValue is the string representation of DiscoKey
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use DiscoKey instead.
|
||||
DiscoKeyValue string `gorm:"column:disco_key"`
|
||||
|
||||
MachineKey key.MachinePublic `gorm:"-"`
|
||||
NodeKey key.NodePublic `gorm:"-"`
|
||||
DiscoKey key.DiscoPublic `gorm:"-"`
|
||||
|
||||
IPAddresses NodeAddresses
|
||||
|
||||
// Hostname represents the name given by the Tailscale
|
||||
|
@ -174,6 +195,31 @@ func (node Node) IsExpired() bool {
|
|||
return time.Now().UTC().After(*node.Expiry)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Try to replace the types in the DB to be correct.
|
||||
func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) {
|
||||
var ret []netip.AddrPort
|
||||
for _, ep := range node.Endpoints {
|
||||
addrPort, err := netip.ParseAddrPort(ep)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = append(ret, addrPort)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): Try to replace the types in the DB to be correct.
|
||||
func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) {
|
||||
var strs StringList
|
||||
for _, addrPort := range in {
|
||||
strs = append(strs, addrPort.String())
|
||||
}
|
||||
|
||||
node.Endpoints = strs
|
||||
}
|
||||
|
||||
// IsOnline returns if the node is connected to Headscale.
|
||||
// This is really a naive implementation, as we don't really see
|
||||
// if there is a working connection between the client and the server.
|
||||
|
@ -226,13 +272,52 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
|||
return found
|
||||
}
|
||||
|
||||
// BeforeSave is a hook that ensures that some values that
|
||||
// cannot be directly marshalled into database values are stored
|
||||
// correctly in the database.
|
||||
// This currently means storing the keys as strings.
|
||||
func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
|
||||
n.MachineKeyValue = n.MachineKey.String()
|
||||
n.NodeKeyValue = n.NodeKey.String()
|
||||
n.DiscoKeyValue = n.DiscoKey.String()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// AfterFind is a hook that ensures that Node objects fields that
|
||||
// has a different type in the database is unwrapped and populated
|
||||
// correctly.
|
||||
// This currently unmarshals all the keys, stored as strings, into
|
||||
// the proper types.
|
||||
func (n *Node) AfterFind(tx *gorm.DB) (err error) {
|
||||
var machineKey key.MachinePublic
|
||||
if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil {
|
||||
return err
|
||||
}
|
||||
n.MachineKey = machineKey
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil {
|
||||
return err
|
||||
}
|
||||
n.NodeKey = nodeKey
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil {
|
||||
return err
|
||||
}
|
||||
n.DiscoKey = discoKey
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (node *Node) Proto() *v1.Node {
|
||||
nodeProto := &v1.Node{
|
||||
Id: node.ID,
|
||||
MachineKey: node.MachineKey,
|
||||
MachineKey: node.MachineKey.String(),
|
||||
|
||||
NodeKey: node.NodeKey,
|
||||
DiscoKey: node.DiscoKey,
|
||||
NodeKey: node.NodeKey.String(),
|
||||
DiscoKey: node.DiscoKey.String(),
|
||||
IpAddresses: node.IPAddresses.StringSlice(),
|
||||
Name: node.Hostname,
|
||||
GivenName: node.GivenName,
|
||||
|
@ -289,47 +374,6 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
|
|||
return hostname, nil
|
||||
}
|
||||
|
||||
func (node *Node) MachinePublicKey() (key.MachinePublic, error) {
|
||||
var machineKey key.MachinePublic
|
||||
|
||||
if node.MachineKey != "" {
|
||||
err := machineKey.UnmarshalText(
|
||||
[]byte(node.MachineKey),
|
||||
)
|
||||
if err != nil {
|
||||
return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return machineKey, nil
|
||||
}
|
||||
|
||||
func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
|
||||
var discoKey key.DiscoPublic
|
||||
if node.DiscoKey != "" {
|
||||
err := discoKey.UnmarshalText(
|
||||
[]byte(node.DiscoKey),
|
||||
)
|
||||
if err != nil {
|
||||
return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||
}
|
||||
} else {
|
||||
discoKey = key.DiscoPublic{}
|
||||
}
|
||||
|
||||
return discoKey, nil
|
||||
}
|
||||
|
||||
func (node *Node) NodePublicKey() (key.NodePublic, error) {
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText([]byte(node.NodeKey))
|
||||
if err != nil {
|
||||
return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err)
|
||||
}
|
||||
|
||||
return nodeKey, nil
|
||||
}
|
||||
|
||||
func (node Node) String() string {
|
||||
return node.Hostname
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue