Redo route code (#2422)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
16868190c8
commit
7891378f57
53 changed files with 2977 additions and 6251 deletions
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
|
@ -92,6 +93,7 @@ type Headscale struct {
|
|||
polManOnce sync.Once
|
||||
polMan policy.PolicyManager
|
||||
extraRecordMan *dns.ExtraRecordsMan
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
|
||||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
|
@ -134,6 +136,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
registrationCache: registrationCache,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
nodeNotifier: notifier.NewNotifier(cfg),
|
||||
primaryRoutes: routes.New(),
|
||||
}
|
||||
|
||||
app.db, err = db.NewHeadscaleDatabase(
|
||||
|
@ -495,6 +498,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
|
||||
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// Maybe this should be implemented as an event bus?
|
||||
// A bool is returned indicating if a full update was sent to all nodes
|
||||
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||
users, err := db.ListUsers()
|
||||
if err != nil {
|
||||
|
@ -516,6 +521,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
|
|||
|
||||
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// Maybe this should be implemented as an event bus?
|
||||
// A bool is returned indicating if a full update was sent to all nodes
|
||||
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) {
|
||||
nodes, err := db.ListNodes()
|
||||
|
@ -566,7 +572,7 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
||||
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
|
||||
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan, h.primaryRoutes)
|
||||
|
||||
if h.cfg.DERP.ServerEnabled {
|
||||
// When embedded DERP is enabled we always need a STUN server
|
||||
|
|
|
@ -84,16 +84,13 @@ func (h *Headscale) handleExistingNode(
|
|||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
changedNodes, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap())
|
||||
err := h.db.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
if changedNodes != nil {
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
|
||||
}
|
||||
}
|
||||
|
||||
expired = true
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -622,6 +623,62 @@ AND auth_key_id NOT IN (
|
|||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
// Migrate all routes from the Route table to the new field ApprovedRoutes
|
||||
// in the Node table. Then drop the Route table.
|
||||
{
|
||||
ID: "202502131714",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") {
|
||||
err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes")
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding column types.Node: %w", err)
|
||||
}
|
||||
}
|
||||
// Ensure the ApprovedRoutes exist.
|
||||
// err := tx.AutoMigrate(&types.Node{})
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("automigrating types.Node: %w", err)
|
||||
// }
|
||||
|
||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
var routes []types.Route
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Enabled {
|
||||
nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for nodeID, routes := range nodeRoutes {
|
||||
slices.SortFunc(routes, util.ComparePrefix)
|
||||
slices.Compact(routes)
|
||||
|
||||
data, err := json.Marshal(routes)
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving approved routes to new column: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
{
|
||||
ID: "202502171819",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
_ = tx.Migrator().DropColumn(&types.Node{}, "last_seen")
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||
{
|
||||
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
|
||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetRoutes(rx)
|
||||
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
n1, err := GetNodeByID(rx, 1)
|
||||
n26, err := GetNodeByID(rx, 26)
|
||||
n31, err := GetNodeByID(rx, 31)
|
||||
n32, err := GetNodeByID(rx, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return types.Nodes{n1, n26, n31, n32}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, routes, 10)
|
||||
want := types.Routes{
|
||||
r(1, "0.0.0.0/0", true, true, false),
|
||||
r(1, "::/0", true, true, false),
|
||||
r(1, "10.9.110.0/24", true, true, true),
|
||||
r(26, "172.100.100.0/24", true, true, true),
|
||||
r(26, "172.100.100.0/24", true, false, false),
|
||||
r(31, "0.0.0.0/0", true, true, false),
|
||||
r(31, "0.0.0.0/0", true, false, false),
|
||||
r(31, "::/0", true, true, false),
|
||||
r(31, "::/0", true, false, false),
|
||||
r(32, "192.168.0.24/32", true, true, true),
|
||||
// want := types.Routes{
|
||||
// r(1, "0.0.0.0/0", true, false),
|
||||
// r(1, "::/0", true, false),
|
||||
// r(1, "10.9.110.0/24", true, true),
|
||||
// r(26, "172.100.100.0/24", true, true),
|
||||
// r(26, "172.100.100.0/24", true, false, false),
|
||||
// r(31, "0.0.0.0/0", true, false),
|
||||
// r(31, "0.0.0.0/0", true, false, false),
|
||||
// r(31, "::/0", true, false),
|
||||
// r(31, "::/0", true, false, false),
|
||||
// r(32, "192.168.0.24/32", true, true),
|
||||
// }
|
||||
want := [][]netip.Prefix{
|
||||
{ipp("0.0.0.0/0"), ipp("10.9.110.0/24"), ipp("::/0")},
|
||||
{ipp("172.100.100.0/24")},
|
||||
{ipp("0.0.0.0/0"), ipp("::/0")},
|
||||
{ipp("192.168.0.24/32")},
|
||||
}
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||
var got [][]netip.Prefix
|
||||
for _, node := range nodes {
|
||||
got = append(got, node.ApprovedRoutes)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
|
@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||
{
|
||||
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
|
||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetRoutes(rx)
|
||||
node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByID(rx, 13)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, routes, 4)
|
||||
want := types.Routes{
|
||||
assert.Len(t, node.ApprovedRoutes, 3)
|
||||
_ = types.Routes{
|
||||
// These routes exists, but have no nodes associated with them
|
||||
// when the migration starts.
|
||||
// r(1, "0.0.0.0/0", true, true, false),
|
||||
|
@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||
r(13, "::/0", true, true, false),
|
||||
r(13, "10.18.80.2/32", true, true, true),
|
||||
}
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||
want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("10.18.80.2/32"), ipp("::/0")}
|
||||
if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
|
@ -225,7 +244,7 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.dbPath, func(t *testing.T) {
|
||||
dbPath, err := testCopyOfDatabase(tt.dbPath)
|
||||
dbPath, err := testCopyOfDatabase(t, tt.dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("copying db for test: %s", err)
|
||||
}
|
||||
|
@ -247,7 +266,7 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testCopyOfDatabase(src string) (string, error) {
|
||||
func testCopyOfDatabase(t *testing.T, src string) (string, error) {
|
||||
sourceFileStat, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -263,11 +282,7 @@ func testCopyOfDatabase(src string) (string, error) {
|
|||
}
|
||||
defer source.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
fn := filepath.Base(src)
|
||||
dst := filepath.Join(tmpDir, fn)
|
||||
|
||||
|
@ -454,3 +469,27 @@ func TestMigrationsPostgres(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func dbForTest(t *testing.T) *HSDatabase {
|
||||
t.Helper()
|
||||
|
||||
dbPath := t.TempDir() + "/headscale_test.db"
|
||||
|
||||
db, err := NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: dbPath,
|
||||
},
|
||||
},
|
||||
"",
|
||||
emptyCache(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("setting up database: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("database set up at: %s", dbPath)
|
||||
|
||||
return db
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
|||
{
|
||||
name: "simple-with-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-with-db")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -119,7 +119,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
|||
{
|
||||
name: "before-after-free-middle-in-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "before-after-free-middle-in-db")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
{
|
||||
name: "simple-backfill-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv6")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
{
|
||||
name: "simple-backfill-ipv4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv4")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
{
|
||||
name: "simple-backfill-remove-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-remove-ipv6")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
{
|
||||
name: "simple-backfill-remove-ipv4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-remove-ipv4")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -407,7 +407,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
{
|
||||
name: "multi-backfill-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv6")
|
||||
db := dbForTest(t)
|
||||
user := types.User{Name: ""}
|
||||
db.DB.Save(&user)
|
||||
|
||||
|
@ -449,7 +449,6 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
"UserID",
|
||||
"Endpoints",
|
||||
"Hostinfo",
|
||||
"Routes",
|
||||
"CreatedAt",
|
||||
"UpdatedAt",
|
||||
))
|
||||
|
@ -488,6 +487,10 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
db,
|
||||
ptr.To(tsaddr.CGNATRange()),
|
||||
|
|
|
@ -12,12 +12,10 @@ import (
|
|||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -50,7 +48,6 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
|||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
Where("id <> ?",
|
||||
nodeID).Find(&nodes).Error; err != nil {
|
||||
return types.Nodes{}, err
|
||||
|
@ -73,7 +70,6 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) {
|
|||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -127,7 +123,6 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
|
|||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
Find(&types.Node{ID: id}).First(&mach); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
@ -151,7 +146,6 @@ func GetNodeByMachineKey(
|
|||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
@ -175,7 +169,6 @@ func GetNodeByNodeKey(
|
|||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
@ -201,7 +194,7 @@ func SetTags(
|
|||
if len(tags) == 0 {
|
||||
// if no tags are provided, we remove all forced tags
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
|
||||
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
|
||||
return fmt.Errorf("removing tags: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -220,7 +213,34 @@ func SetTags(
|
|||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for node in the database: %w", err)
|
||||
return fmt.Errorf("updating tags: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
routes []netip.Prefix,
|
||||
) error {
|
||||
if len(routes) == 0 {
|
||||
// if no routes are provided, we remove all
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
|
||||
return fmt.Errorf("removing approved routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(routes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
||||
return fmt.Errorf("updating approved routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -267,9 +287,9 @@ func NodeSetExpiry(tx *gorm.DB,
|
|||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||
return DeleteNode(tx, node, isLikelyConnected)
|
||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DeleteNode(tx, node)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -277,19 +297,13 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.Ma
|
|||
// Caller is responsible for notifying all of change.
|
||||
func DeleteNode(tx *gorm.DB,
|
||||
node *types.Node,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
) ([]types.NodeID, error) {
|
||||
changed, err := deleteNodeRoutes(tx, node, isLikelyConnected)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
}
|
||||
|
||||
) error {
|
||||
// Unscoped causes the node to be fully removed from the database.
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||
return changed, err
|
||||
return err
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteEphemeralNode deletes a Node from the database, note that this method
|
||||
|
@ -306,12 +320,6 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
|||
})
|
||||
}
|
||||
|
||||
// SetLastSeen sets a node's last seen field indicating that we
|
||||
// have recently communicating with this node.
|
||||
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
||||
}
|
||||
|
||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
||||
// with a registrationID to register or reauthenticate a node.
|
||||
// If the node found in the registration cache is not already registered,
|
||||
|
@ -458,10 +466,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
}
|
||||
|
||||
if _, err := SaveNodeRoutes(tx, &node); err != nil {
|
||||
return nil, fmt.Errorf("failed to save node routes: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
|
@ -504,141 +508,6 @@ func NodeSave(tx *gorm.DB, node *types.Node) error {
|
|||
return tx.Save(node).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
|
||||
return GetAdvertisedRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAdvertisedRoutes returns the routes that are be advertised by the given node.
|
||||
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("getting advertised routes for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
var prefixes []netip.Prefix
|
||||
for _, route := range routes {
|
||||
prefixes = append(prefixes, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
|
||||
return GetEnabledRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// GetEnabledRoutes returns the routes that are enabled for the node.
|
||||
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("getting enabled routes for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
var prefixes []netip.Prefix
|
||||
for _, route := range routes {
|
||||
prefixes = append(prefixes, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := GetEnabledRoutes(tx, node)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, enabledRoute := range enabledRoutes {
|
||||
if route == enabledRoute {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) enableRoutes(
|
||||
node *types.Node,
|
||||
newRoutes ...netip.Prefix,
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return enableRoutes(tx, node, newRoutes...)
|
||||
})
|
||||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func enableRoutes(tx *gorm.DB,
|
||||
node *types.Node, newRoutes ...netip.Prefix,
|
||||
) (*types.StateUpdate, error) {
|
||||
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !slices.Contains(advertisedRoutes, newRoute) {
|
||||
return nil, fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
node.Hostname,
|
||||
newRoute, ErrNodeRouteIsNotAvailable,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Separate loop so we don't leave things in a half-updated state
|
||||
for _, prefix := range newRoutes {
|
||||
route := types.Route{}
|
||||
err := tx.Preload("Node").
|
||||
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
|
||||
First(&route).Error
|
||||
if err == nil {
|
||||
route.Enabled = true
|
||||
|
||||
// Mark already as primary if there is only this node offering this subnet
|
||||
// (and is not an exit route)
|
||||
if !route.IsExitRoute() {
|
||||
route.IsPrimary = isUniquePrefix(tx, route)
|
||||
}
|
||||
|
||||
err = tx.Save(&route).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to enable route: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to find route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the node has the latest routes when notifying the other
|
||||
// nodes
|
||||
nRoutes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read back routes: %w", err)
|
||||
}
|
||||
|
||||
node.Routes = nRoutes
|
||||
|
||||
return ptr.To(types.UpdatePeerChanged(node.ID)), nil
|
||||
}
|
||||
|
||||
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
suppliedName = util.ConvertWithFQDNRules(suppliedName)
|
||||
if len(suppliedName) > util.LabelHostnameLength {
|
||||
|
|
|
@ -15,12 +15,10 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
|
@ -102,7 +100,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
|||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
||||
err = db.DeleteNode(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||
|
@ -458,142 +456,143 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAutoApproveRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "2068-approve-issue-sub",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:k8s": ["test"]
|
||||
},
|
||||
// TODO(kradalby): replace this test
|
||||
// func TestAutoApproveRoutes(t *testing.T) {
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// acl string
|
||||
// routes []netip.Prefix
|
||||
// want []netip.Prefix
|
||||
// }{
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "groups": {
|
||||
// "group:k8s": ["test"]
|
||||
// },
|
||||
|
||||
"acls": [
|
||||
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
],
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.42.0.0/16": ["test"],
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub",
|
||||
acl: `
|
||||
{
|
||||
"tagOwners": {
|
||||
"tag:exit": ["test"],
|
||||
},
|
||||
// "autoApprovers": {
|
||||
// "routes": {
|
||||
// "10.42.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// },
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "tagOwners": {
|
||||
// "tag:exit": ["test"],
|
||||
// },
|
||||
|
||||
"groups": {
|
||||
"group:test": ["test"]
|
||||
},
|
||||
// "groups": {
|
||||
// "group:test": ["test"]
|
||||
// },
|
||||
|
||||
"acls": [
|
||||
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
],
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
"autoApprovers": {
|
||||
"exitNode": ["tag:exit"],
|
||||
"routes": {
|
||||
"10.10.0.0/16": ["group:test"],
|
||||
"10.11.0.0/16": ["test"],
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
},
|
||||
want: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
},
|
||||
}
|
||||
// "autoApprovers": {
|
||||
// "exitNode": ["tag:exit"],
|
||||
// "routes": {
|
||||
// "10.10.0.0/16": ["group:test"],
|
||||
// "10.11.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// tsaddr.AllIPv6(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// },
|
||||
// want: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// tsaddr.AllIPv6(),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adb, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// adb, err := newSQLiteTestDB()
|
||||
// require.NoError(t, err)
|
||||
// pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pol)
|
||||
// require.NoError(t, err)
|
||||
// require.NotNil(t, pol)
|
||||
|
||||
user, err := adb.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
// user, err := adb.CreateUser(types.User{Name: "test"})
|
||||
// require.NoError(t, err)
|
||||
|
||||
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
// pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, nil, nil)
|
||||
// require.NoError(t, err)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
// nodeKey := key.NewNode()
|
||||
// machineKey := key.NewMachine()
|
||||
|
||||
v4 := netip.MustParseAddr("100.64.0.1")
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "test",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RequestTags: []string{"tag:exit"},
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
IPv4: &v4,
|
||||
}
|
||||
// v4 := netip.MustParseAddr("100.64.0.1")
|
||||
// node := types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKey.Public(),
|
||||
// NodeKey: nodeKey.Public(),
|
||||
// Hostname: "test",
|
||||
// UserID: user.ID,
|
||||
// RegisterMethod: util.RegisterMethodAuthKey,
|
||||
// AuthKeyID: ptr.To(pak.ID),
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RequestTags: []string{"tag:exit"},
|
||||
// RoutableIPs: tt.routes,
|
||||
// },
|
||||
// IPv4: &v4,
|
||||
// }
|
||||
|
||||
trx := adb.DB.Save(&node)
|
||||
require.NoError(t, trx.Error)
|
||||
// trx := adb.DB.Save(&node)
|
||||
// require.NoError(t, trx.Error)
|
||||
|
||||
sendUpdate, err := adb.SaveNodeRoutes(&node)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sendUpdate)
|
||||
// sendUpdate, err := adb.SaveNodeRoutes(&node)
|
||||
// require.NoError(t, err)
|
||||
// assert.False(t, sendUpdate)
|
||||
|
||||
node0ByID, err := adb.GetNodeByID(0)
|
||||
require.NoError(t, err)
|
||||
// node0ByID, err := adb.GetNodeByID(0)
|
||||
// require.NoError(t, err)
|
||||
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
// users, err := adb.ListUsers()
|
||||
// assert.NoError(t, err)
|
||||
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
// nodes, err := adb.ListNodes()
|
||||
// assert.NoError(t, err)
|
||||
|
||||
pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
|
||||
assert.NoError(t, err)
|
||||
// pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
|
||||
// assert.NoError(t, err)
|
||||
|
||||
// TODO(kradalby): Check state update
|
||||
err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
|
||||
require.NoError(t, err)
|
||||
// // TODO(kradalby): Check state update
|
||||
// err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
|
||||
// require.NoError(t, err)
|
||||
|
||||
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, enabledRoutes, len(tt.want))
|
||||
// enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||
// require.NoError(t, err)
|
||||
// assert.Len(t, enabledRoutes, len(tt.want))
|
||||
|
||||
tsaddr.SortPrefixes(enabledRoutes)
|
||||
// tsaddr.SortPrefixes(enabledRoutes)
|
||||
|
||||
if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
|
||||
// t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
|
|
|
@ -1,676 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("advertised = ? AND enabled = ?", true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("prefix = ?", pref.String()).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ? AND advertised = true", node.ID).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodeRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ?", node.ID).
|
||||
Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
First(&route, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if route.IsExitRoute() {
|
||||
return enableRoutes(
|
||||
tx,
|
||||
route.Node,
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
)
|
||||
}
|
||||
|
||||
return enableRoutes(tx, route.Node, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
func DisableRoute(tx *gorm.DB,
|
||||
id uint64,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
) ([]types.NodeID, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
node := route.Node
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
var update []types.NodeID
|
||||
if !route.IsExitRoute() {
|
||||
route.Enabled = false
|
||||
err = tx.Save(route).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update, err = failoverRouteTx(tx, isLikelyConnected, route)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
routes, err = GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if routes[i].IsExitRoute() {
|
||||
routes[i].Enabled = false
|
||||
routes[i].IsPrimary = false
|
||||
|
||||
err = tx.Save(&routes[i]).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if update == nil {
|
||||
update = []types.NodeID{node.ID}
|
||||
}
|
||||
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(
|
||||
id uint64,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
) ([]types.NodeID, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||
return DeleteRoute(tx, id, isLikelyConnected)
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteRoute(
|
||||
tx *gorm.DB,
|
||||
id uint64,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
) ([]types.NodeID, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if route.Node == nil {
|
||||
// If the route is not assigned to a node, just delete it,
|
||||
// there are no updates to be sent as no nodes are
|
||||
// dependent on it
|
||||
if err := tx.Unscoped().Delete(&route).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
node := route.Node
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
// This means that if we delete a route which is an exit route, delete both.
|
||||
var update []types.NodeID
|
||||
if route.IsExitRoute() {
|
||||
routes, err = GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routesToDelete types.Routes
|
||||
for _, r := range routes {
|
||||
if r.IsExitRoute() {
|
||||
routesToDelete = append(routesToDelete, r)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
update, err = failoverRouteTx(tx, isLikelyConnected, route)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := tx.Unscoped().Delete(&route).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if routes == nil {
|
||||
routes, err = GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node.Routes = routes
|
||||
|
||||
if update == nil {
|
||||
update = []types.NodeID{node.ID}
|
||||
}
|
||||
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
|
||||
routes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting node routes: %w", err)
|
||||
}
|
||||
|
||||
var changed []types.NodeID
|
||||
for i := range routes {
|
||||
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
return nil, fmt.Errorf("deleting route(%d): %w", &routes[i].ID, err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): This is a bit too aggressive, we could probably
|
||||
// figure out which routes needs to be failed over rather than all.
|
||||
chn, err := failoverRouteTx(tx, isLikelyConnected, &routes[i])
|
||||
if err != nil {
|
||||
return changed, fmt.Errorf("failing over route after delete: %w", err)
|
||||
}
|
||||
|
||||
if chn != nil {
|
||||
changed = append(changed, chn...)
|
||||
}
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another node providing the same route already.
|
||||
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
|
||||
var count int64
|
||||
tx.Model(&types.Route{}).
|
||||
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix.String(),
|
||||
route.NodeID,
|
||||
true, true).Count(&count)
|
||||
|
||||
return count == 0
|
||||
}
|
||||
|
||||
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
|
||||
First(&route).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodePrimaryRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
|
||||
return SaveNodeRoutes(tx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNodeRoutes takes a node and updates the database with
|
||||
// the new routes.
|
||||
// It returns a bool whether an update should be sent as the
|
||||
// saved route impacts nodes.
|
||||
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||
sendUpdate := false
|
||||
|
||||
currentRoutes := types.Routes{}
|
||||
err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
||||
advertisedRoutes := map[netip.Prefix]bool{}
|
||||
for _, prefix := range node.Hostinfo.RoutableIPs {
|
||||
advertisedRoutes[prefix] = false
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Interface("advertisedRoutes", advertisedRoutes).
|
||||
Interface("currentRoutes", currentRoutes).
|
||||
Msg("updating routes")
|
||||
|
||||
for pos, route := range currentRoutes {
|
||||
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
||||
if !route.Advertised {
|
||||
currentRoutes[pos].Advertised = true
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
||||
// If a route that is newly "saved" is already
|
||||
// enabled, set sendUpdate to true as it is now
|
||||
// available.
|
||||
if route.Enabled {
|
||||
sendUpdate = true
|
||||
}
|
||||
}
|
||||
advertisedRoutes[netip.Prefix(route.Prefix)] = true
|
||||
} else if route.Advertised {
|
||||
currentRoutes[pos].Advertised = false
|
||||
currentRoutes[pos].Enabled = false
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for prefix, exists := range advertisedRoutes {
|
||||
if !exists {
|
||||
route := types.Route{
|
||||
NodeID: node.ID.Uint64(),
|
||||
Prefix: prefix,
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
err := tx.Create(&route).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sendUpdate, nil
|
||||
}
|
||||
|
||||
// FailoverNodeRoutesIfNecessary takes a node and checks if the node's route
|
||||
// need to be failed over to another host.
|
||||
// If needed, the failover will be attempted.
|
||||
func FailoverNodeRoutesIfNecessary(
|
||||
tx *gorm.DB,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
node *types.Node,
|
||||
) (*types.StateUpdate, error) {
|
||||
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
changedNodes := make(set.Set[types.NodeID])
|
||||
|
||||
nodeRouteLoop:
|
||||
for _, nodeRoute := range nodeRoutes {
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.IsPrimary {
|
||||
// if we have a primary route, and the node is connected
|
||||
// nothing needs to be done.
|
||||
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
|
||||
// if not, we need to failover the route
|
||||
failover := failoverRoute(isLikelyConnected, &route, routes)
|
||||
if failover != nil {
|
||||
err := failover.save(tx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving failover routes: %w", err)
|
||||
}
|
||||
|
||||
changedNodes.Add(failover.old.Node.ID)
|
||||
changedNodes.Add(failover.new.Node.ID)
|
||||
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chng := changedNodes.Slice()
|
||||
sort.SliceStable(chng, func(i, j int) bool {
|
||||
return chng[i] < chng[j]
|
||||
})
|
||||
|
||||
if len(changedNodes) != 0 {
|
||||
return ptr.To(types.UpdatePeerChanged(chng...)), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// failoverRouteTx takes a route that is no longer available,
|
||||
// this can be either from:
|
||||
// - being disabled
|
||||
// - being deleted
|
||||
// - host going offline
|
||||
//
|
||||
// and tries to find a new route to take over its place.
|
||||
// If the given route was not primary, it returns early.
|
||||
func failoverRouteTx(
|
||||
tx *gorm.DB,
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
r *types.Route,
|
||||
) ([]types.NodeID, error) {
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This route is not a primary route, and it is not
|
||||
// being served to nodes.
|
||||
if !r.IsPrimary {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We do not have to failover exit nodes
|
||||
if r.IsExitRoute() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
||||
}
|
||||
|
||||
fo := failoverRoute(isLikelyConnected, r, routes)
|
||||
if fo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = fo.save(tx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving failover route: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", fo.new.Node.Hostname).
|
||||
Msgf("set primary to new route, was: id(%d), host(%s), now: id(%d), host(%s)", fo.old.ID, fo.old.Node.Hostname, fo.new.ID, fo.new.Node.Hostname)
|
||||
|
||||
// Return a list of the machinekeys of the changed nodes.
|
||||
return []types.NodeID{fo.old.Node.ID, fo.new.Node.ID}, nil
|
||||
}
|
||||
|
||||
type failover struct {
|
||||
old *types.Route
|
||||
new *types.Route
|
||||
}
|
||||
|
||||
func (f *failover) save(tx *gorm.DB) error {
|
||||
err := tx.Save(f.old).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving old primary: %w", err)
|
||||
}
|
||||
|
||||
err = tx.Save(f.new).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving new primary: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func failoverRoute(
|
||||
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||
routeToReplace *types.Route,
|
||||
altRoutes types.Routes,
|
||||
) *failover {
|
||||
if routeToReplace == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This route is not a primary route, and it is not
|
||||
// being served to nodes.
|
||||
if !routeToReplace.IsPrimary {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We do not have to failover exit nodes
|
||||
if routeToReplace.IsExitRoute() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var newPrimary *types.Route
|
||||
|
||||
// Find a new suitable route
|
||||
for idx, route := range altRoutes {
|
||||
if routeToReplace.ID == route.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if !route.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if isLikelyConnected != nil {
|
||||
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
|
||||
newPrimary = &altRoutes[idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a new route was not found/available,
|
||||
// return without an error.
|
||||
// We do not want to update the database as
|
||||
// the one currently marked as primary is the
|
||||
// best we got.
|
||||
if newPrimary == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeToReplace.IsPrimary = false
|
||||
newPrimary.IsPrimary = true
|
||||
|
||||
return &failover{
|
||||
old: routeToReplace,
|
||||
new: newPrimary,
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||
polMan policy.PolicyManager,
|
||||
node *types.Node,
|
||||
) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return EnableAutoApprovedRoutes(tx, polMan, node)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||
func EnableAutoApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
polMan policy.PolicyManager,
|
||||
node *types.Node,
|
||||
) error {
|
||||
if node.IPv4 == nil && node.IPv6 == nil {
|
||||
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
routes, err := GetNodeAdvertisedRoutes(tx, node)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("getting advertised routes for node(%s %d): %w", node.Hostname, node.ID, err)
|
||||
}
|
||||
|
||||
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
|
||||
|
||||
var approvedRoutes types.Routes
|
||||
|
||||
for _, advertisedRoute := range routes {
|
||||
if advertisedRoute.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Uint("user.id", node.User.ID).
|
||||
Strs("routeApprovers", routeApprovers).
|
||||
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
||||
Msg("looking up route for autoapproving")
|
||||
|
||||
for _, approvedAlias := range routeApprovers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||
approvedIps, err := polMan.ExpandAlias(approvedAlias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if approvedIps.Contains(*node.IPv4) {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
_, err := EnableRoute(tx, uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("enabling approved route(%d): %w", approvedRoute.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -100,6 +100,11 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(registrationsJSON)
|
||||
}))
|
||||
debug.Handle("routes", "Routes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.primaryRoutes.String()))
|
||||
}))
|
||||
|
||||
err := statsviz.Register(debugMux)
|
||||
if err == nil {
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -18,6 +20,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
|
@ -326,6 +329,51 @@ func (api headscaleV1APIServer) SetTags(
|
|||
return &v1.SetTagsResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.SetApprovedRoutesRequest,
|
||||
) (*v1.SetApprovedRoutesResponse, error) {
|
||||
var routes []netip.Prefix
|
||||
for _, route := range request.GetRoutes() {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing route: %w", err)
|
||||
}
|
||||
|
||||
// If the prefix is an exit route, add both. The client expect both
|
||||
// to annotate the node as an exit node.
|
||||
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
|
||||
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||
} else {
|
||||
routes = append(routes, prefix)
|
||||
}
|
||||
}
|
||||
slices.SortFunc(routes, util.ComparePrefix)
|
||||
slices.Compact(routes)
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.SetApprovedRoutes(tx, types.NodeID(request.GetNodeId()), routes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
}
|
||||
|
||||
return &v1.SetApprovedRoutesResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) error {
|
||||
if strings.Index(tag, "tag:") != 0 {
|
||||
return errors.New("tag must start with the string 'tag:'")
|
||||
|
@ -348,10 +396,7 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
changedNodes, err := api.h.db.DeleteNode(
|
||||
node,
|
||||
api.h.nodeNotifier.LikelyConnectedMap(),
|
||||
)
|
||||
err = api.h.db.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -359,10 +404,6 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
|
||||
if changedNodes != nil {
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
|
||||
}
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -533,100 +574,6 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
|
|||
return &v1.BackfillNodeIPsResponse{Changes: changes}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) GetRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.GetRoutesRequest,
|
||||
) (*v1.GetRoutesResponse, error) {
|
||||
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return db.GetRoutes(rx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.GetRoutesResponse{
|
||||
Routes: types.Routes(routes).Proto(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) EnableRoute(
|
||||
ctx context.Context,
|
||||
request *v1.EnableRouteRequest,
|
||||
) (*v1.EnableRouteResponse, error) {
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.EnableRoute(tx, request.GetRouteId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(
|
||||
ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.EnableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) DisableRoute(
|
||||
ctx context.Context,
|
||||
request *v1.DisableRouteRequest,
|
||||
) (*v1.DisableRouteResponse, error) {
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
|
||||
}
|
||||
|
||||
return &v1.DisableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) GetNodeRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.GetNodeRoutesRequest,
|
||||
) (*v1.GetNodeRoutesResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes, err := api.h.db.GetNodeRoutes(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.GetNodeRoutesResponse{
|
||||
Routes: types.Routes(routes).Proto(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) DeleteRoute(
|
||||
ctx context.Context,
|
||||
request *v1.DeleteRouteRequest,
|
||||
) (*v1.DeleteRouteResponse, error) {
|
||||
isConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
|
||||
}
|
||||
|
||||
return &v1.DeleteRouteResponse{}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) CreateApiKey(
|
||||
ctx context.Context,
|
||||
request *v1.CreateApiKeyRequest,
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
@ -56,6 +57,7 @@ type Mapper struct {
|
|||
derpMap *tailcfg.DERPMap
|
||||
notif *notifier.Notifier
|
||||
polMan policy.PolicyManager
|
||||
primary *routes.PrimaryRoutes
|
||||
|
||||
uid string
|
||||
created time.Time
|
||||
|
@ -73,6 +75,7 @@ func NewMapper(
|
|||
derpMap *tailcfg.DERPMap,
|
||||
notif *notifier.Notifier,
|
||||
polMan policy.PolicyManager,
|
||||
primary *routes.PrimaryRoutes,
|
||||
) *Mapper {
|
||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
|
@ -82,6 +85,7 @@ func NewMapper(
|
|||
derpMap: derpMap,
|
||||
notif: notif,
|
||||
polMan: polMan,
|
||||
primary: primary,
|
||||
|
||||
uid: uid,
|
||||
created: time.Now(),
|
||||
|
@ -97,15 +101,22 @@ func generateUserProfiles(
|
|||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]types.User)
|
||||
userMap[node.User.ID] = node.User
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = peer.User // not worth checking if already is there
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
slices.Compact(ids)
|
||||
var profiles []tailcfg.UserProfile
|
||||
for _, user := range userMap {
|
||||
profiles = append(profiles, user.TailscaleUserProfile())
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
}
|
||||
}
|
||||
|
||||
return profiles
|
||||
|
@ -166,6 +177,7 @@ func (m *Mapper) fullMapResponse(
|
|||
resp,
|
||||
true, // full change
|
||||
m.polMan,
|
||||
m.primary,
|
||||
node,
|
||||
capVer,
|
||||
peers,
|
||||
|
@ -271,6 +283,7 @@ func (m *Mapper) PeerChangedResponse(
|
|||
&resp,
|
||||
false, // partial change
|
||||
m.polMan,
|
||||
m.primary,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes,
|
||||
|
@ -299,7 +312,7 @@ func (m *Mapper) PeerChangedResponse(
|
|||
|
||||
// Add the node itself, it might have changed, and particularly
|
||||
// if there are no patches or changes, this is a self update.
|
||||
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
|
||||
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.primary, m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -446,7 +459,7 @@ func (m *Mapper) baseWithConfigMapResponse(
|
|||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
|
||||
tailnode, err := tailNode(node, capVer, m.polMan, m.primary, m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -500,6 +513,7 @@ func appendPeerChanges(
|
|||
|
||||
fullChange bool,
|
||||
polMan policy.PolicyManager,
|
||||
primary *routes.PrimaryRoutes,
|
||||
node *types.Node,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changed types.Nodes,
|
||||
|
@ -522,7 +536,7 @@ func appendPeerChanges(
|
|||
|
||||
dnsConfig := generateDNSConfig(cfg, node)
|
||||
|
||||
tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
|
||||
tailPeers, err := tailNodes(changed, capVer, polMan, primary, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,12 +6,11 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -24,51 +23,6 @@ var iap = func(ipStr string) *netip.Addr {
|
|||
return &ip
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
mach := func(hostname, username string, userid uint) *types.Node {
|
||||
return &types.Node{
|
||||
Hostname: hostname,
|
||||
UserID: userid,
|
||||
User: types.User{
|
||||
Model: gorm.Model{
|
||||
ID: userid,
|
||||
},
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
nodeInShared1 := mach("test_get_shared_nodes_1", "user1", 1)
|
||||
nodeInShared2 := mach("test_get_shared_nodes_2", "user2", 2)
|
||||
nodeInShared3 := mach("test_get_shared_nodes_3", "user3", 3)
|
||||
node2InShared1 := mach("test_get_shared_nodes_4", "user1", 1)
|
||||
|
||||
userProfiles := generateUserProfiles(
|
||||
nodeInShared1,
|
||||
types.Nodes{
|
||||
nodeInShared2, nodeInShared3, node2InShared1,
|
||||
},
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
users := []string{
|
||||
"user1", "user2", "user3",
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
found := false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == user {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSConfigMapResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
magicDNS bool
|
||||
|
@ -159,11 +113,11 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
|
||||
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
||||
|
||||
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
|
||||
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2"}
|
||||
|
||||
mini := &types.Node{
|
||||
ID: 0,
|
||||
ID: 1,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
|
@ -182,35 +136,22 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: tsaddr.AllIPv4(),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: false,
|
||||
},
|
||||
{
|
||||
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: true,
|
||||
},
|
||||
{
|
||||
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
IsPrimary: true,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.0.0.0/10"),
|
||||
},
|
||||
},
|
||||
CreatedAt: created,
|
||||
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
||||
tailMini := &tailcfg.Node{
|
||||
ID: 0,
|
||||
StableID: "0",
|
||||
ID: 1,
|
||||
StableID: "1",
|
||||
Name: "mini",
|
||||
User: 0,
|
||||
User: tailcfg.UserID(user1.ID),
|
||||
Key: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
|
@ -227,12 +168,17 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
HomeDERP: 0,
|
||||
LegacyDERPString: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
HomeDERP: 0,
|
||||
LegacyDERPString: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.0.0.0/10"),
|
||||
},
|
||||
}),
|
||||
Created: created,
|
||||
Tags: []string{},
|
||||
PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||
LastSeen: &lastSeen,
|
||||
MachineAuthorized: true,
|
||||
|
||||
|
@ -244,7 +190,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
peer1 := &types.Node{
|
||||
ID: 1,
|
||||
ID: 2,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
|
@ -257,20 +203,20 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
IPv4: iap("100.64.0.2"),
|
||||
Hostname: "peer1",
|
||||
GivenName: "peer1",
|
||||
UserID: user1.ID,
|
||||
User: user1,
|
||||
UserID: user2.ID,
|
||||
User: user2,
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
||||
tailPeer1 := &tailcfg.Node{
|
||||
ID: 1,
|
||||
StableID: "1",
|
||||
ID: 2,
|
||||
StableID: "2",
|
||||
Name: "peer1",
|
||||
User: tailcfg.UserID(user2.ID),
|
||||
Key: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
|
@ -288,7 +234,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
Tags: []string{},
|
||||
PrimaryRoutes: []netip.Prefix{},
|
||||
LastSeen: &lastSeen,
|
||||
MachineAuthorized: true,
|
||||
|
||||
|
@ -299,30 +244,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
peer2 := &types.Node{
|
||||
ID: 2,
|
||||
MachineKey: mustMK(
|
||||
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
|
||||
),
|
||||
NodeKey: mustNK(
|
||||
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||
),
|
||||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPv4: iap("100.64.0.3"),
|
||||
Hostname: "peer2",
|
||||
GivenName: "peer2",
|
||||
UserID: user2.ID,
|
||||
User: user2,
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *policy.ACLPolicy
|
||||
|
@ -364,7 +285,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
|
@ -398,9 +319,12 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
|
@ -410,6 +334,9 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
{
|
||||
name: "with-pol-map-response",
|
||||
pol: &policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"mini": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
|
@ -421,7 +348,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
peer2,
|
||||
},
|
||||
derpMap: &tailcfg.DERPMap{},
|
||||
cfg: &types.Config{
|
||||
|
@ -449,7 +375,8 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{LoginName: "mini", DisplayName: "mini"},
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
|
@ -464,6 +391,12 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
primary := routes.New()
|
||||
|
||||
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
|
||||
for _, peer := range tt.peers {
|
||||
primary.SetRoutes(peer.ID, peer.SubnetRoutes()...)
|
||||
}
|
||||
|
||||
mappy := NewMapper(
|
||||
nil,
|
||||
|
@ -471,6 +404,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
tt.derpMap,
|
||||
nil,
|
||||
polMan,
|
||||
primary,
|
||||
)
|
||||
|
||||
got, err := mappy.fullMapResponse(
|
||||
|
@ -485,8 +419,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
spew.Dump(got)
|
||||
|
||||
if diff := cmp.Diff(
|
||||
tt.want,
|
||||
got,
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -15,6 +16,7 @@ func tailNodes(
|
|||
nodes types.Nodes,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
polMan policy.PolicyManager,
|
||||
primary *routes.PrimaryRoutes,
|
||||
cfg *types.Config,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
tNodes := make([]*tailcfg.Node, len(nodes))
|
||||
|
@ -24,6 +26,7 @@ func tailNodes(
|
|||
node,
|
||||
capVer,
|
||||
polMan,
|
||||
primary,
|
||||
cfg,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -41,6 +44,7 @@ func tailNode(
|
|||
node *types.Node,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
polMan policy.PolicyManager,
|
||||
primary *routes.PrimaryRoutes,
|
||||
cfg *types.Config,
|
||||
) (*tailcfg.Node, error) {
|
||||
addrs := node.Prefixes()
|
||||
|
@ -49,17 +53,8 @@ func tailNode(
|
|||
[]netip.Prefix{},
|
||||
addrs...) // we append the node own IP, as it is required by the clients
|
||||
|
||||
primaryPrefixes := []netip.Prefix{}
|
||||
|
||||
for _, route := range node.Routes {
|
||||
if route.Enabled {
|
||||
if route.IsPrimary {
|
||||
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix))
|
||||
primaryPrefixes = append(primaryPrefixes, netip.Prefix(route.Prefix))
|
||||
} else if route.IsExitRoute() {
|
||||
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix))
|
||||
}
|
||||
}
|
||||
for _, route := range node.SubnetRoutes() {
|
||||
allowedIPs = append(allowedIPs, netip.Prefix(route))
|
||||
}
|
||||
|
||||
var derp int
|
||||
|
@ -103,6 +98,7 @@ func tailNode(
|
|||
Machine: node.MachineKey,
|
||||
DiscoKey: node.DiscoKey,
|
||||
Addresses: addrs,
|
||||
PrimaryRoutes: primary.PrimaryRoutes(node.ID),
|
||||
AllowedIPs: allowedIPs,
|
||||
Endpoints: node.Endpoints,
|
||||
HomeDERP: derp,
|
||||
|
@ -114,8 +110,6 @@ func tailNode(
|
|||
|
||||
Tags: tags,
|
||||
|
||||
PrimaryRoutes: primaryPrefixes,
|
||||
|
||||
MachineAuthorized: !node.IsExpired(),
|
||||
Expired: node.IsExpired(),
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -72,7 +73,6 @@ func TestTailNode(t *testing.T) {
|
|||
LegacyDERPString: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Tags: []string{},
|
||||
PrimaryRoutes: []netip.Prefix{},
|
||||
MachineAuthorized: true,
|
||||
|
||||
CapMap: tailcfg.NodeCapMap{
|
||||
|
@ -107,28 +107,15 @@ func TestTailNode(t *testing.T) {
|
|||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: tsaddr.AllIPv4(),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: false,
|
||||
},
|
||||
{
|
||||
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: true,
|
||||
},
|
||||
{
|
||||
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
IsPrimary: true,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.0.0.0/10"),
|
||||
},
|
||||
},
|
||||
CreatedAt: created,
|
||||
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
|
||||
CreatedAt: created,
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
|
@ -159,8 +146,14 @@ func TestTailNode(t *testing.T) {
|
|||
},
|
||||
HomeDERP: 0,
|
||||
LegacyDERPString: "127.3.3.40:0",
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{}),
|
||||
Created: created,
|
||||
Hostinfo: hiview(tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.0.0.0/10"),
|
||||
},
|
||||
}),
|
||||
Created: created,
|
||||
|
||||
Tags: []string{},
|
||||
|
||||
|
@ -187,15 +180,22 @@ func TestTailNode(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
primary := routes.New()
|
||||
cfg := &types.Config{
|
||||
BaseDomain: tt.baseDomain,
|
||||
TailcfgDNSConfig: tt.dnsConfig,
|
||||
RandomizeClientPort: false,
|
||||
}
|
||||
_ = primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
|
||||
|
||||
// This is a hack to avoid having a second node to test the primary route.
|
||||
// This should be baked into the test case proper if it is extended in the future.
|
||||
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
|
||||
got, err := tailNode(
|
||||
tt.node,
|
||||
0,
|
||||
polMan,
|
||||
primary,
|
||||
cfg,
|
||||
)
|
||||
|
||||
|
@ -249,6 +249,7 @@ func TestNodeExpiry(t *testing.T) {
|
|||
node,
|
||||
0,
|
||||
&policy.PolicyManagerV1{},
|
||||
nil,
|
||||
&types.Config{},
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -243,6 +243,7 @@ func (pol *ACLPolicy) CompileFilterRules(
|
|||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
// TODO(kradalby): Make this nil and not alloc unless needed
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
|
@ -264,13 +265,11 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
|||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if node.Hostinfo != nil {
|
||||
if len(node.Hostinfo.RoutableIPs) > 0 {
|
||||
for _, routableIP := range node.Hostinfo.RoutableIPs {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
if len(node.SubnetRoutes()) > 0 {
|
||||
for _, routableIP := range node.SubnetRoutes() {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2165,6 +2165,9 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
},
|
||||
},
|
||||
ApprovedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -2292,6 +2295,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
ApprovedRoutes: tsaddr.ExitRoutes(),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -2398,6 +2402,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
ApprovedRoutes: tsaddr.ExitRoutes(),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -2513,6 +2518,10 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
netip.MustParsePrefix("16.0.0.0/16"),
|
||||
},
|
||||
},
|
||||
ApprovedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("8.0.0.0/16"),
|
||||
netip.MustParsePrefix("16.0.0.0/16"),
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -2603,6 +2612,10 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
netip.MustParsePrefix("16.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
ApprovedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("8.0.0.0/8"),
|
||||
netip.MustParsePrefix("16.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -2683,7 +2696,8 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
},
|
||||
ForcedTags: []string{"tag:access-servers"},
|
||||
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
ForcedTags: []string{"tag:access-servers"},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
|
@ -3475,14 +3489,10 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
|||
IPv4: iap("100.64.0.2"),
|
||||
Hostname: "router",
|
||||
User: types.User{Name: "router"},
|
||||
Routes: types.Routes{
|
||||
types.Route{
|
||||
NodeID: 2,
|
||||
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
},
|
||||
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
},
|
||||
},
|
||||
rules: []tailcfg.FilterRule{
|
||||
|
@ -3508,14 +3518,10 @@ func Test_getFilteredByACLPeers(t *testing.T) {
|
|||
IPv4: iap("100.64.0.2"),
|
||||
Hostname: "router",
|
||||
User: types.User{Name: "router"},
|
||||
Routes: types.Routes{
|
||||
types.Route{
|
||||
NodeID: 2,
|
||||
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
},
|
||||
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
)
|
||||
|
||||
type Match struct {
|
||||
Srcs *netipx.IPSet
|
||||
Dests *netipx.IPSet
|
||||
srcs *netipx.IPSet
|
||||
dests *netipx.IPSet
|
||||
}
|
||||
|
||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||
|
@ -42,16 +42,16 @@ func MatchFromStrings(sources, destinations []string) Match {
|
|||
destsSet, _ := dests.IPSet()
|
||||
|
||||
match := Match{
|
||||
Srcs: srcsSet,
|
||||
Dests: destsSet,
|
||||
srcs: srcsSet,
|
||||
dests: destsSet,
|
||||
}
|
||||
|
||||
return match
|
||||
}
|
||||
|
||||
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
||||
func (m *Match) SrcsContainsIPs(ips ...netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Srcs.Contains(ip) {
|
||||
if m.srcs.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -59,9 +59,29 @@ func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
|
||||
func (m *Match) DestsContainsIP(ips ...netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Dests.Contains(ip) {
|
||||
if m.dests.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Match) SrcsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
|
||||
for _, prefix := range prefixes {
|
||||
if m.srcs.ContainsPrefix(prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
|
||||
for _, prefix := range prefixes {
|
||||
if m.dests.ContainsPrefix(prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,9 @@ type PolicyManager interface {
|
|||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||
}
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
|
@ -185,3 +188,32 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
|||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -7,16 +7,15 @@ import (
|
|||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
@ -205,7 +204,15 @@ func (m *mapSession) serveLongPoll() {
|
|||
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
|
||||
// Failover the node's routes if any.
|
||||
m.h.updateNodeOnlineStatus(false, m.node)
|
||||
m.pollFailoverRoutes("node closing connection", m.node)
|
||||
|
||||
// When a node disconnects, and it causes the primary route map to change,
|
||||
// send a full update to all nodes.
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
if m.h.primaryRoutes.SetRoutes(m.node.ID) {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
}
|
||||
|
||||
m.afterServeLongPoll()
|
||||
|
@ -216,7 +223,10 @@ func (m *mapSession) serveLongPoll() {
|
|||
m.h.pollNetMapStreamWG.Add(1)
|
||||
defer m.h.pollNetMapStreamWG.Done()
|
||||
|
||||
m.pollFailoverRoutes("node connected", m.node)
|
||||
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
// Upgrade the writer to a ResponseController
|
||||
rc := http.NewResponseController(m.w)
|
||||
|
@ -383,22 +393,6 @@ func (m *mapSession) serveLongPoll() {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
|
||||
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
|
||||
})
|
||||
if err != nil {
|
||||
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if update != nil && !update.Empty() {
|
||||
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
|
||||
// about change in their online/offline status.
|
||||
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
|
||||
|
@ -414,15 +408,6 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
|
|||
// lastSeen is only relevant if the node is disconnected.
|
||||
node.LastSeen = &now
|
||||
change.LastSeen = &now
|
||||
|
||||
err := h.db.Write(func(tx *gorm.DB) error {
|
||||
return db.SetLastSeen(tx, node.ID, *node.LastSeen)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Cannot update node LastSeen")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
|
||||
|
@ -471,36 +456,47 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||
// If the hostinfo has changed, but not the routes, just update
|
||||
// hostinfo and let the function continue.
|
||||
if routesChanged {
|
||||
var err error
|
||||
_, err = m.h.db.SaveNodeRoutes(m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error processing node routes")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(kradalby): Only update the node that has actually changed
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
||||
|
||||
if m.h.polMan != nil {
|
||||
// update routes with peer information
|
||||
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error running auto approved routes")
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
// Take all the routes presented to us by the node and check
|
||||
// if any of them should be auto approved by the policy.
|
||||
// If any of them are, add them to the approved routes of the node.
|
||||
// Keep all the old entries and compact the list to remove duplicates.
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range m.node.Hostinfo.RoutableIPs {
|
||||
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, m.node.ApprovedRoutes...)
|
||||
slices.SortFunc(newApproved, util.ComparePrefix)
|
||||
slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
m.node.ApprovedRoutes = newApproved
|
||||
|
||||
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(m.node.ID), m.node.ID)
|
||||
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(m.node.ID),
|
||||
m.node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(m.node.ID),
|
||||
m.node.ID)
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
|
|
186
hscontrol/routes/primary.go
Normal file
186
hscontrol/routes/primary.go
Normal file
|
@ -0,0 +1,186 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
type PrimaryRoutes struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// routes is a map of prefixes that are adverties and approved and available
|
||||
// in the global headscale state.
|
||||
routes map[types.NodeID]set.Set[netip.Prefix]
|
||||
|
||||
// primaries is a map of prefixes to the node that is the primary for that prefix.
|
||||
primaries map[netip.Prefix]types.NodeID
|
||||
isPrimary map[types.NodeID]bool
|
||||
}
|
||||
|
||||
func New() *PrimaryRoutes {
|
||||
return &PrimaryRoutes{
|
||||
routes: make(map[types.NodeID]set.Set[netip.Prefix]),
|
||||
primaries: make(map[netip.Prefix]types.NodeID),
|
||||
isPrimary: make(map[types.NodeID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// updatePrimaryLocked recalculates the primary routes and updates the internal state.
|
||||
// It returns true if the primary routes have changed.
|
||||
// It is assumed that the caller holds the lock.
|
||||
// The algorthm is as follows:
|
||||
// 1. Reset the primaries map.
|
||||
// 2. Iterate over the routes and count the number of times a prefix is advertised.
|
||||
// 3. If a prefix is advertised by at least two nodes, it is a primary route.
|
||||
// 4. If the primary routes have changed, update the internal state and return true.
|
||||
// 5. Otherwise, return false.
|
||||
func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
||||
// reset the primaries map, as we are going to recalculate it.
|
||||
allPrimaries := make(map[netip.Prefix][]types.NodeID)
|
||||
pr.isPrimary = make(map[types.NodeID]bool)
|
||||
changed := false
|
||||
|
||||
// sort the node ids so we can iterate over them in a deterministic order.
|
||||
// this is important so the same node is chosen two times in a row
|
||||
// as the primary route.
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
sort.Sort(ids)
|
||||
|
||||
// Create a map of prefixes to nodes that serve them so we
|
||||
// can determine the primary route for each prefix.
|
||||
for _, id := range ids {
|
||||
routes := pr.routes[id]
|
||||
for route := range routes {
|
||||
if _, ok := allPrimaries[route]; !ok {
|
||||
allPrimaries[route] = []types.NodeID{id}
|
||||
} else {
|
||||
allPrimaries[route] = append(allPrimaries[route], id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Go through all prefixes and determine the primary route for each.
|
||||
// If the number of routes is below the minimum, remove the primary.
|
||||
// If the current primary is still available, continue.
|
||||
// If the current primary is not available, select a new one.
|
||||
for prefix, nodes := range allPrimaries {
|
||||
if node, ok := pr.primaries[prefix]; ok {
|
||||
if len(nodes) < 2 {
|
||||
delete(pr.primaries, prefix)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
// If the current primary is still available, continue.
|
||||
if slices.Contains(nodes, node) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(nodes) >= 2 {
|
||||
pr.primaries[prefix] = nodes[0]
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up any remaining primaries that are no longer valid.
|
||||
for prefix := range pr.primaries {
|
||||
if _, ok := allPrimaries[prefix]; !ok {
|
||||
delete(pr.primaries, prefix)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the quick lookup index for primary routes
|
||||
for _, nodeID := range pr.primaries {
|
||||
pr.isPrimary[nodeID] = true
|
||||
}
|
||||
|
||||
return changed
|
||||
}
|
||||
|
||||
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool {
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
|
||||
// If no routes are being set, remove the node from the routes map.
|
||||
if len(prefix) == 0 {
|
||||
log.Printf("Removing node %d from routes", node)
|
||||
if _, ok := pr.routes[node]; ok {
|
||||
delete(pr.routes, node)
|
||||
return pr.updatePrimaryLocked()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := pr.routes[node]; !ok {
|
||||
pr.routes[node] = make(set.Set[netip.Prefix], len(prefix))
|
||||
}
|
||||
|
||||
for _, p := range prefix {
|
||||
pr.routes[node].Add(p)
|
||||
}
|
||||
|
||||
return pr.updatePrimaryLocked()
|
||||
}
|
||||
|
||||
func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
|
||||
if pr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
|
||||
// Short circuit if the node is not a primary for any route.
|
||||
if _, ok := pr.isPrimary[id]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var routes []netip.Prefix
|
||||
|
||||
for prefix, node := range pr.primaries {
|
||||
if node == id {
|
||||
routes = append(routes, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (pr *PrimaryRoutes) String() string {
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
|
||||
return pr.stringLocked()
|
||||
}
|
||||
|
||||
func (pr *PrimaryRoutes) stringLocked() string {
|
||||
var sb strings.Builder
|
||||
|
||||
fmt.Fprintln(&sb, "Available routes:")
|
||||
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
sort.Sort(ids)
|
||||
for _, id := range ids {
|
||||
prefixes := pr.routes[id]
|
||||
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
|
||||
}
|
||||
|
||||
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
|
||||
for route, nodeID := range pr.primaries {
|
||||
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
316
hscontrol/routes/primary_test.go
Normal file
316
hscontrol/routes/primary_test.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
// mp is a helper function that wraps netip.MustParsePrefix.
|
||||
func mp(prefix string) netip.Prefix {
|
||||
return netip.MustParsePrefix(prefix)
|
||||
}
|
||||
|
||||
func TestPrimaryRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
operations func(pr *PrimaryRoutes) bool
|
||||
nodeID types.NodeID
|
||||
expectedRoutes []netip.Prefix
|
||||
expectedChange bool
|
||||
}{
|
||||
{
|
||||
name: "single-node-registers-single-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
return pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "multiple-nodes-register-different-routes",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
return pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "multiple-nodes-register-overlapping-routes",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
return pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "node-deregisters-a-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
return pr.SetRoutes(1) // Deregister by setting no routes
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "node-deregisters-one-of-multiple-routes",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24"))
|
||||
return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "node-registers-and-deregisters-routes-in-sequence",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
pr.SetRoutes(1) // Deregister by setting no routes
|
||||
return pr.SetRoutes(1, mp("192.168.3.0/24"))
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "no-change-in-primary-routes",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
return pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "multiple-nodes-register-same-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "register-multiple-routes-shift-primary-check-old-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "register-multiple-routes-shift-primary-check-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "register-multiple-routes-shift-primary-check-non-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
nodeID: 3,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "primary-route-map-is-cleared-up-no-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(2) // true, no primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "primary-route-map-is-cleared-up-all-no-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
pr.SetRoutes(2) // true, no primary
|
||||
|
||||
return pr.SetRoutes(3) // false, no primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "primary-route-map-is-cleared-up",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(2) // true, no primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "primary-route-no-flake",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "primary-route-no-flake-check-old-primary",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "primary-route-no-flake-full-integration",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
pr.SetRoutes(2) // true, no primary
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
|
||||
},
|
||||
nodeID: 2,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "multiple-nodes-register-same-route-and-exit",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24"))
|
||||
return pr.SetRoutes(2, mp("192.168.1.0/24"))
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
|
||||
expectedChange: true,
|
||||
},
|
||||
{
|
||||
name: "deregister-non-existent-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
return pr.SetRoutes(1) // Deregister by setting no routes
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "register-empty-prefix-list",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
return pr.SetRoutes(1)
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "deregister-empty-prefix-list",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
return pr.SetRoutes(1)
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "concurrent-access",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
var change1, change2 bool
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
return change1 || change2
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
{
|
||||
name: "no-routes-registered",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
// No operations
|
||||
return false
|
||||
},
|
||||
nodeID: 1,
|
||||
expectedRoutes: nil,
|
||||
expectedChange: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pr := New()
|
||||
change := tt.operations(pr)
|
||||
if change != tt.expectedChange {
|
||||
t.Errorf("change = %v, want %v", change, tt.expectedChange)
|
||||
}
|
||||
routes := pr.PrimaryRoutes(tt.nodeID)
|
||||
if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" {
|
||||
t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -25,8 +26,11 @@ var (
|
|||
)
|
||||
|
||||
type NodeID uint64
|
||||
type NodeIDs []NodeID
|
||||
|
||||
// type NodeConnectedMap *xsync.MapOf[NodeID, bool]
|
||||
func (n NodeIDs) Len() int { return len(n) }
|
||||
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
|
||||
func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||
|
||||
func (id NodeID) StableID() tailcfg.StableNodeID {
|
||||
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
|
||||
|
@ -84,10 +88,21 @@ type Node struct {
|
|||
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
|
||||
AuthKey *PreAuthKey
|
||||
|
||||
LastSeen *time.Time
|
||||
Expiry *time.Time
|
||||
Expiry *time.Time
|
||||
|
||||
Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
|
||||
// LastSeen is when the node was last in contact with
|
||||
// headscale. It is best effort and not persisted.
|
||||
LastSeen *time.Time `gorm:"-"`
|
||||
|
||||
// DEPRECATED: Use the ApprovedRoutes field instead.
|
||||
// TODO(kradalby): remove when ApprovedRoutes is used all over the code.
|
||||
// Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// ApprovedRoutes is a list of routes that the node is allowed to announce
|
||||
// as a subnet router. They are not necessarily the routes that the node
|
||||
// announces at the moment.
|
||||
// See [Node.Hostinfo]
|
||||
ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"`
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
@ -96,9 +111,7 @@ type Node struct {
|
|||
IsOnline *bool `gorm:"-"`
|
||||
}
|
||||
|
||||
type (
|
||||
Nodes []*Node
|
||||
)
|
||||
type Nodes []*Node
|
||||
|
||||
// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node.
|
||||
func (node *Node) GivenNameHasBeenChanged() bool {
|
||||
|
@ -185,23 +198,22 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
|||
|
||||
// TODO(kradalby): Regenerate this every time the filter change, instead of
|
||||
// every time we use it.
|
||||
// Part of #2416
|
||||
matchers := make([]matcher.Match, len(filter))
|
||||
for i, rule := range filter {
|
||||
matchers[i] = matcher.MatchFromFilterRule(rule)
|
||||
}
|
||||
|
||||
for _, route := range node2.Routes {
|
||||
if route.Enabled {
|
||||
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, matcher := range matchers {
|
||||
if !matcher.SrcsContainsIPs(src) {
|
||||
if !matcher.SrcsContainsIPs(src...) {
|
||||
continue
|
||||
}
|
||||
|
||||
if matcher.DestsContainsIP(allowedIPs) {
|
||||
if matcher.DestsContainsIP(allowedIPs...) {
|
||||
return true
|
||||
}
|
||||
|
||||
if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -245,11 +257,14 @@ func (node *Node) Proto() *v1.Node {
|
|||
DiscoKey: node.DiscoKey.String(),
|
||||
|
||||
// TODO(kradalby): replace list with v4, v6 field?
|
||||
IpAddresses: node.IPsAsString(),
|
||||
Name: node.Hostname,
|
||||
GivenName: node.GivenName,
|
||||
User: node.User.Proto(),
|
||||
ForcedTags: node.ForcedTags,
|
||||
IpAddresses: node.IPsAsString(),
|
||||
Name: node.Hostname,
|
||||
GivenName: node.GivenName,
|
||||
User: node.User.Proto(),
|
||||
ForcedTags: node.ForcedTags,
|
||||
ApprovedRoutes: util.PrefixesToString(node.ApprovedRoutes),
|
||||
AvailableRoutes: util.PrefixesToString(node.AnnouncedRoutes()),
|
||||
SubnetRoutes: util.PrefixesToString(node.SubnetRoutes()),
|
||||
|
||||
RegisterMethod: node.RegisterMethodToV1Enum(),
|
||||
|
||||
|
@ -297,6 +312,29 @@ func (node *Node) GetFQDN(baseDomain string) (string, error) {
|
|||
return hostname, nil
|
||||
}
|
||||
|
||||
// AnnouncedRoutes returns the list of routes that the node announces.
|
||||
// It should be used instead of checking Hostinfo.RoutableIPs directly.
|
||||
func (node *Node) AnnouncedRoutes() []netip.Prefix {
|
||||
if node.Hostinfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return node.Hostinfo.RoutableIPs
|
||||
}
|
||||
|
||||
// SubnetRoutes returns the list of routes that the node announces and are approved.
|
||||
func (node *Node) SubnetRoutes() []netip.Prefix {
|
||||
var routes []netip.Prefix
|
||||
|
||||
for _, route := range node.AnnouncedRoutes() {
|
||||
if slices.Contains(node.ApprovedRoutes, route) {
|
||||
routes = append(routes, route)
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// func (node *Node) String() string {
|
||||
// return node.Hostname
|
||||
// }
|
||||
|
|
|
@ -1,102 +1,31 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
// Deprecated: Approval of routes is denormalised onto the relevant node.
|
||||
// Struct is kept for GORM migrations only.
|
||||
type Route struct {
|
||||
gorm.Model
|
||||
|
||||
NodeID uint64 `gorm:"not null"`
|
||||
Node *Node
|
||||
|
||||
// TODO(kradalby): change this custom type to netip.Prefix
|
||||
Prefix netip.Prefix `gorm:"serializer:text"`
|
||||
|
||||
// Advertised is now only stored as part of [Node.Hostinfo].
|
||||
Advertised bool
|
||||
Enabled bool
|
||||
IsPrimary bool
|
||||
|
||||
// Enabled is stored directly on the node as ApprovedRoutes.
|
||||
Enabled bool
|
||||
|
||||
// IsPrimary is only determined in memory as it is only relevant
|
||||
// when the server is up.
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
// Deprecated: Approval of routes is denormalised onto the relevant node.
|
||||
type Routes []Route
|
||||
|
||||
func (r *Route) String() string {
|
||||
return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String())
|
||||
}
|
||||
|
||||
func (r *Route) IsExitRoute() bool {
|
||||
return tsaddr.IsExitRoute(r.Prefix)
|
||||
}
|
||||
|
||||
func (r *Route) IsAnnouncable() bool {
|
||||
return r.Advertised && r.Enabled
|
||||
}
|
||||
|
||||
func (rs Routes) Prefixes() []netip.Prefix {
|
||||
prefixes := make([]netip.Prefix, len(rs))
|
||||
for i, r := range rs {
|
||||
prefixes[i] = netip.Prefix(r.Prefix)
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
// Primaries returns Primary routes from a list of routes.
|
||||
func (rs Routes) Primaries() Routes {
|
||||
res := make(Routes, 0)
|
||||
for _, route := range rs {
|
||||
if route.IsPrimary {
|
||||
res = append(res, route)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (rs Routes) PrefixMap() map[netip.Prefix][]Route {
|
||||
res := map[netip.Prefix][]Route{}
|
||||
|
||||
for _, route := range rs {
|
||||
if _, ok := res[route.Prefix]; ok {
|
||||
res[route.Prefix] = append(res[route.Prefix], route)
|
||||
} else {
|
||||
res[route.Prefix] = []Route{route}
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (rs Routes) Proto() []*v1.Route {
|
||||
protoRoutes := []*v1.Route{}
|
||||
|
||||
for _, route := range rs {
|
||||
protoRoute := v1.Route{
|
||||
Id: uint64(route.ID),
|
||||
Prefix: route.Prefix.String(),
|
||||
Advertised: route.Advertised,
|
||||
Enabled: route.Enabled,
|
||||
IsPrimary: route.IsPrimary,
|
||||
CreatedAt: timestamppb.New(route.CreatedAt),
|
||||
UpdatedAt: timestamppb.New(route.UpdatedAt),
|
||||
}
|
||||
|
||||
if route.Node != nil {
|
||||
protoRoute.Node = route.Node.Proto()
|
||||
}
|
||||
|
||||
if route.DeletedAt.Valid {
|
||||
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
|
||||
}
|
||||
|
||||
protoRoutes = append(protoRoutes, &protoRoute)
|
||||
}
|
||||
|
||||
return protoRoutes
|
||||
}
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
func TestPrefixMap(t *testing.T) {
|
||||
ipp := func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
|
||||
|
||||
tests := []struct {
|
||||
rs Routes
|
||||
want map[netip.Prefix][]Route
|
||||
}{
|
||||
{
|
||||
rs: Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
},
|
||||
},
|
||||
want: map[netip.Prefix][]Route{
|
||||
ipp("10.0.0.0/24"): Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
rs: Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
},
|
||||
Route{
|
||||
Prefix: ipp("10.0.1.0/24"),
|
||||
},
|
||||
},
|
||||
want: map[netip.Prefix][]Route{
|
||||
ipp("10.0.0.0/24"): Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
},
|
||||
},
|
||||
ipp("10.0.1.0/24"): Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.1.0/24"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
rs: Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Enabled: true,
|
||||
},
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
want: map[netip.Prefix][]Route{
|
||||
ipp("10.0.0.0/24"): Routes{
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Enabled: true,
|
||||
},
|
||||
Route{
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, tt := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
|
||||
got := tt.rs.PrefixMap()
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -55,6 +55,13 @@ type User struct {
|
|||
ProfilePicURL string
|
||||
}
|
||||
|
||||
func (u *User) StringID() string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatUint(uint64(u.ID), 10)
|
||||
}
|
||||
|
||||
// Username is the main way to get the username of a user,
|
||||
// it will return the email if it exists, the name if it exists,
|
||||
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
|
||||
|
@ -63,7 +70,11 @@ type User struct {
|
|||
// should be used throughout headscale, in information returned to the
|
||||
// user and the Policy engine.
|
||||
func (u *User) Username() string {
|
||||
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
|
||||
return cmp.Or(
|
||||
u.Email,
|
||||
u.Name,
|
||||
u.ProviderIdentifier.String,
|
||||
u.StringID())
|
||||
}
|
||||
|
||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
|
@ -10,3 +12,40 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
|||
|
||||
return d.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Remove when in stdlib;
|
||||
// https://github.com/golang/go/issues/61642
|
||||
// Compare returns an integer comparing two prefixes.
|
||||
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
|
||||
// Prefixes sort first by validity (invalid before valid), then
|
||||
// address family (IPv4 before IPv6), then prefix length, then
|
||||
// address.
|
||||
func ComparePrefix(p, p2 netip.Prefix) int {
|
||||
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
|
||||
return c
|
||||
}
|
||||
|
||||
return p.Addr().Compare(p2.Addr())
|
||||
}
|
||||
|
||||
func PrefixesToString(prefixes []netip.Prefix) []string {
|
||||
ret := make([]string, 0, len(prefixes))
|
||||
for _, prefix := range prefixes {
|
||||
ret = append(ret, prefix.String())
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func MustStringsToPrefixes(strings []string) []netip.Prefix {
|
||||
ret := make([]netip.Prefix, 0, len(strings))
|
||||
for _, str := range strings {
|
||||
prefix := netip.MustParsePrefix(str)
|
||||
ret = append(ret, prefix)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
|
|
@ -74,3 +74,21 @@ func TailMapResponseToString(resp tailcfg.MapResponse) string {
|
|||
TailNodesToString(resp.Peers),
|
||||
)
|
||||
}
|
||||
|
||||
func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
|
||||
var sb strings.Builder
|
||||
|
||||
for index, rule := range rules {
|
||||
sb.WriteString(fmt.Sprintf(`
|
||||
{
|
||||
SrcIPs: %v
|
||||
DstIPs: %v
|
||||
}
|
||||
`, rule.SrcIPs, rule.DstPorts))
|
||||
if index < len(rules)-1 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", sb.String(), len(rules))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue