Redo route code (#2422)

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-26 07:22:55 -08:00 committed by GitHub
parent 16868190c8
commit 7891378f57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
53 changed files with 2977 additions and 6251 deletions

View file

@ -32,6 +32,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
@ -92,6 +93,7 @@ type Headscale struct {
polManOnce sync.Once
polMan policy.PolicyManager
extraRecordMan *dns.ExtraRecordsMan
primaryRoutes *routes.PrimaryRoutes
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
@ -134,6 +136,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
primaryRoutes: routes.New(),
}
app.db, err = db.NewHeadscaleDatabase(
@ -495,6 +498,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
@ -516,6 +521,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
// Maybe this should be implemented as an event bus?
// A bool is returned indicating if a full update was sent to all nodes
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) {
nodes, err := db.ListNodes()
@ -566,7 +572,7 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan, h.primaryRoutes)
if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server

View file

@ -84,16 +84,13 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
changedNodes, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap())
err := h.db.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
if changedNodes != nil {
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}
}
expired = true

View file

@ -8,6 +8,7 @@ import (
"fmt"
"net/netip"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
@ -622,6 +623,62 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
// Migrate all routes from the Route table to the new field ApprovedRoutes
// in the Node table. Then drop the Route table.
{
ID: "202502131714",
Migrate: func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") {
err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes")
if err != nil {
return fmt.Errorf("adding column types.Node: %w", err)
}
}
// Ensure the ApprovedRoutes exist.
// err := tx.AutoMigrate(&types.Node{})
// if err != nil {
// return fmt.Errorf("automigrating types.Node: %w", err)
// }
nodeRoutes := map[uint64][]netip.Prefix{}
var routes []types.Route
err = tx.Find(&routes).Error
if err != nil {
return fmt.Errorf("fetching routes: %w", err)
}
for _, route := range routes {
if route.Enabled {
nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix)
}
}
for nodeID, routes := range nodeRoutes {
slices.SortFunc(routes, util.ComparePrefix)
slices.Compact(routes)
data, err := json.Marshal(routes)
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
if err != nil {
return fmt.Errorf("saving approved routes to new column: %w", err)
}
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202502171819",
Migrate: func(tx *gorm.DB) error {
_ = tx.Migrator().DropColumn(&types.Node{}, "last_seen")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View file

@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) {
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
n1, err := GetNodeByID(rx, 1)
n26, err := GetNodeByID(rx, 26)
n31, err := GetNodeByID(rx, 31)
n32, err := GetNodeByID(rx, 32)
if err != nil {
return nil, err
}
return types.Nodes{n1, n26, n31, n32}, nil
})
require.NoError(t, err)
assert.Len(t, routes, 10)
want := types.Routes{
r(1, "0.0.0.0/0", true, true, false),
r(1, "::/0", true, true, false),
r(1, "10.9.110.0/24", true, true, true),
r(26, "172.100.100.0/24", true, true, true),
r(26, "172.100.100.0/24", true, false, false),
r(31, "0.0.0.0/0", true, true, false),
r(31, "0.0.0.0/0", true, false, false),
r(31, "::/0", true, true, false),
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
// want := types.Routes{
// r(1, "0.0.0.0/0", true, false),
// r(1, "::/0", true, false),
// r(1, "10.9.110.0/24", true, true),
// r(26, "172.100.100.0/24", true, true),
// r(26, "172.100.100.0/24", true, false, false),
// r(31, "0.0.0.0/0", true, false),
// r(31, "0.0.0.0/0", true, false, false),
// r(31, "::/0", true, false),
// r(31, "::/0", true, false, false),
// r(32, "192.168.0.24/32", true, true),
// }
want := [][]netip.Prefix{
{ipp("0.0.0.0/0"), ipp("10.9.110.0/24"), ipp("::/0")},
{ipp("172.100.100.0/24")},
{ipp("0.0.0.0/0"), ipp("::/0")},
{ipp("192.168.0.24/32")},
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
var got [][]netip.Prefix
for _, node := range nodes {
got = append(got, node.ApprovedRoutes)
}
if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) {
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, 13)
})
require.NoError(t, err)
assert.Len(t, routes, 4)
want := types.Routes{
assert.Len(t, node.ApprovedRoutes, 3)
_ = types.Routes{
// These routes exists, but have no nodes associated with them
// when the migration starts.
// r(1, "0.0.0.0/0", true, true, false),
@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) {
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("10.18.80.2/32"), ipp("::/0")}
if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
@ -225,7 +244,7 @@ func TestMigrationsSQLite(t *testing.T) {
for _, tt := range tests {
t.Run(tt.dbPath, func(t *testing.T) {
dbPath, err := testCopyOfDatabase(tt.dbPath)
dbPath, err := testCopyOfDatabase(t, tt.dbPath)
if err != nil {
t.Fatalf("copying db for test: %s", err)
}
@ -247,7 +266,7 @@ func TestMigrationsSQLite(t *testing.T) {
}
}
func testCopyOfDatabase(src string) (string, error) {
func testCopyOfDatabase(t *testing.T, src string) (string, error) {
sourceFileStat, err := os.Stat(src)
if err != nil {
return "", err
@ -263,11 +282,7 @@ func testCopyOfDatabase(src string) (string, error) {
}
defer source.Close()
tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
if err != nil {
return "", err
}
tmpDir := t.TempDir()
fn := filepath.Base(src)
dst := filepath.Join(tmpDir, fn)
@ -454,3 +469,27 @@ func TestMigrationsPostgres(t *testing.T) {
})
}
}
func dbForTest(t *testing.T) *HSDatabase {
t.Helper()
dbPath := t.TempDir() + "/headscale_test.db"
db, err := NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
},
"",
emptyCache(),
)
if err != nil {
t.Fatalf("setting up database: %s", err)
}
t.Logf("database set up at: %s", dbPath)
return db
}

View file

@ -91,7 +91,7 @@ func TestIPAllocatorSequential(t *testing.T) {
{
name: "simple-with-db",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-with-db")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -119,7 +119,7 @@ func TestIPAllocatorSequential(t *testing.T) {
{
name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "before-after-free-middle-in-db")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) {
{
name: "simple-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) {
{
name: "simple-backfill-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv4")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) {
{
name: "simple-backfill-remove-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv6")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) {
{
name: "simple-backfill-remove-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv4")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -407,7 +407,7 @@ func TestBackfillIPAddresses(t *testing.T) {
{
name: "multi-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
db := dbForTest(t)
user := types.User{Name: ""}
db.DB.Save(&user)
@ -449,7 +449,6 @@ func TestBackfillIPAddresses(t *testing.T) {
"UserID",
"Endpoints",
"Hostinfo",
"Routes",
"CreatedAt",
"UpdatedAt",
))
@ -488,6 +487,10 @@ func TestBackfillIPAddresses(t *testing.T) {
}
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
defer db.Close()
alloc, err := NewIPAllocator(
db,
ptr.To(tsaddr.CGNATRange()),

View file

@ -12,12 +12,10 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -50,7 +48,6 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("id <> ?",
nodeID).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
@ -73,7 +70,6 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Find(&nodes).Error; err != nil {
return nil, err
}
@ -127,7 +123,6 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Find(&types.Node{ID: id}).First(&mach); result.Error != nil {
return nil, result.Error
}
@ -151,7 +146,6 @@ func GetNodeByMachineKey(
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
return nil, result.Error
}
@ -175,7 +169,6 @@ func GetNodeByNodeKey(
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
return nil, result.Error
}
@ -201,7 +194,7 @@ func SetTags(
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
return fmt.Errorf("removing tags: %w", err)
}
return nil
@ -220,7 +213,34 @@ func SetTags(
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
return fmt.Errorf("updating tags: %w", err)
}
return nil
}
// SetTags takes a Node struct pointer and update the forced tags.
func SetApprovedRoutes(
tx *gorm.DB,
nodeID types.NodeID,
routes []netip.Prefix,
) error {
if len(routes) == 0 {
// if no routes are provided, we remove all
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
return fmt.Errorf("removing approved routes: %w", err)
}
return nil
}
b, err := json.Marshal(routes)
if err != nil {
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
return fmt.Errorf("updating approved routes: %w", err)
}
return nil
@ -267,9 +287,9 @@ func NodeSetExpiry(tx *gorm.DB,
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteNode(tx, node, isLikelyConnected)
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DeleteNode(tx, node)
})
}
@ -277,19 +297,13 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.Ma
// Caller is responsible for notifying all of change.
func DeleteNode(tx *gorm.DB,
node *types.Node,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) {
changed, err := deleteNodeRoutes(tx, node, isLikelyConnected)
if err != nil {
return changed, err
}
) error {
// Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
return changed, err
return err
}
return changed, nil
return nil
}
// DeleteEphemeralNode deletes a Node from the database, note that this method
@ -306,12 +320,6 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
})
}
// SetLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node.
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
@ -458,10 +466,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
}
if _, err := SaveNodeRoutes(tx, &node); err != nil {
return nil, fmt.Errorf("failed to save node routes: %w", err)
}
log.Trace().
Caller().
Str("node", node.Hostname).
@ -504,141 +508,6 @@ func NodeSave(tx *gorm.DB, node *types.Node) error {
return tx.Save(node).Error
}
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetAdvertisedRoutes(rx, node)
})
}
// GetAdvertisedRoutes returns the routes that are be advertised by the given node.
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("getting advertised routes for node(%d): %w", node.ID, err)
}
var prefixes []netip.Prefix
for _, route := range routes {
prefixes = append(prefixes, netip.Prefix(route.Prefix))
}
return prefixes, nil
}
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetEnabledRoutes(rx, node)
})
}
// GetEnabledRoutes returns the routes that are enabled for the node.
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("getting enabled routes for node(%d): %w", node.ID, err)
}
var prefixes []netip.Prefix
for _, route := range routes {
prefixes = append(prefixes, netip.Prefix(route.Prefix))
}
return prefixes, nil
}
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := GetEnabledRoutes(tx, node)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
func (hsdb *HSDatabase) enableRoutes(
node *types.Node,
newRoutes ...netip.Prefix,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return enableRoutes(tx, node, newRoutes...)
})
}
// enableRoutes enables new routes based on a list of new routes.
func enableRoutes(tx *gorm.DB,
node *types.Node, newRoutes ...netip.Prefix,
) (*types.StateUpdate, error) {
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
if err != nil {
return nil, err
}
for _, newRoute := range newRoutes {
if !slices.Contains(advertisedRoutes, newRoute) {
return nil, fmt.Errorf(
"route (%s) is not available on node %s: %w",
node.Hostname,
newRoute, ErrNodeRouteIsNotAvailable,
)
}
}
// Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes {
route := types.Route{}
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
First(&route).Error
if err == nil {
route.Enabled = true
// Mark already as primary if there is only this node offering this subnet
// (and is not an exit route)
if !route.IsExitRoute() {
route.IsPrimary = isUniquePrefix(tx, route)
}
err = tx.Save(&route).Error
if err != nil {
return nil, fmt.Errorf("failed to enable route: %w", err)
}
} else {
return nil, fmt.Errorf("failed to find route: %w", err)
}
}
// Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil, fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
return ptr.To(types.UpdatePeerChanged(node.ID)), nil
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
suppliedName = util.ConvertWithFQDNRules(suppliedName)
if len(suppliedName) > util.LabelHostnameLength {

View file

@ -15,12 +15,10 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
@ -102,7 +100,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
err = db.DeleteNode(&node)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode3")
@ -458,142 +456,143 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}
}
func TestAutoApproveRoutes(t *testing.T) {
tests := []struct {
name string
acl string
routes []netip.Prefix
want []netip.Prefix
}{
{
name: "2068-approve-issue-sub",
acl: `
{
"groups": {
"group:k8s": ["test"]
},
// TODO(kradalby): replace this test
// func TestAutoApproveRoutes(t *testing.T) {
// tests := []struct {
// name string
// acl string
// routes []netip.Prefix
// want []netip.Prefix
// }{
// {
// name: "2068-approve-issue-sub",
// acl: `
// {
// "groups": {
// "group:k8s": ["test"]
// },
"acls": [
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
],
// "acls": [
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
// ],
"autoApprovers": {
"routes": {
"10.42.0.0/16": ["test"],
}
}
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
},
{
name: "2068-approve-issue-sub",
acl: `
{
"tagOwners": {
"tag:exit": ["test"],
},
// "autoApprovers": {
// "routes": {
// "10.42.0.0/16": ["test"],
// }
// }
// }`,
// routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
// want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
// },
// {
// name: "2068-approve-issue-sub",
// acl: `
// {
// "tagOwners": {
// "tag:exit": ["test"],
// },
"groups": {
"group:test": ["test"]
},
// "groups": {
// "group:test": ["test"]
// },
"acls": [
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
],
// "acls": [
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
// ],
"autoApprovers": {
"exitNode": ["tag:exit"],
"routes": {
"10.10.0.0/16": ["group:test"],
"10.11.0.0/16": ["test"],
}
}
}`,
routes: []netip.Prefix{
tsaddr.AllIPv4(),
tsaddr.AllIPv6(),
netip.MustParsePrefix("10.10.0.0/16"),
netip.MustParsePrefix("10.11.0.0/24"),
},
want: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("10.10.0.0/16"),
netip.MustParsePrefix("10.11.0.0/24"),
tsaddr.AllIPv6(),
},
},
}
// "autoApprovers": {
// "exitNode": ["tag:exit"],
// "routes": {
// "10.10.0.0/16": ["group:test"],
// "10.11.0.0/16": ["test"],
// }
// }
// }`,
// routes: []netip.Prefix{
// tsaddr.AllIPv4(),
// tsaddr.AllIPv6(),
// netip.MustParsePrefix("10.10.0.0/16"),
// netip.MustParsePrefix("10.11.0.0/24"),
// },
// want: []netip.Prefix{
// tsaddr.AllIPv4(),
// netip.MustParsePrefix("10.10.0.0/16"),
// netip.MustParsePrefix("10.11.0.0/24"),
// tsaddr.AllIPv6(),
// },
// },
// }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
adb, err := newSQLiteTestDB()
require.NoError(t, err)
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// adb, err := newSQLiteTestDB()
// require.NoError(t, err)
// pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
require.NoError(t, err)
require.NotNil(t, pol)
// require.NoError(t, err)
// require.NotNil(t, pol)
user, err := adb.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
// user, err := adb.CreateUser(types.User{Name: "test"})
// require.NoError(t, err)
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
require.NoError(t, err)
// pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, nil, nil)
// require.NoError(t, err)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
// nodeKey := key.NewNode()
// machineKey := key.NewMachine()
v4 := netip.MustParseAddr("100.64.0.1")
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: tt.routes,
},
IPv4: &v4,
}
// v4 := netip.MustParseAddr("100.64.0.1")
// node := types.Node{
// ID: 0,
// MachineKey: machineKey.Public(),
// NodeKey: nodeKey.Public(),
// Hostname: "test",
// UserID: user.ID,
// RegisterMethod: util.RegisterMethodAuthKey,
// AuthKeyID: ptr.To(pak.ID),
// Hostinfo: &tailcfg.Hostinfo{
// RequestTags: []string{"tag:exit"},
// RoutableIPs: tt.routes,
// },
// IPv4: &v4,
// }
trx := adb.DB.Save(&node)
require.NoError(t, trx.Error)
// trx := adb.DB.Save(&node)
// require.NoError(t, trx.Error)
sendUpdate, err := adb.SaveNodeRoutes(&node)
require.NoError(t, err)
assert.False(t, sendUpdate)
// sendUpdate, err := adb.SaveNodeRoutes(&node)
// require.NoError(t, err)
// assert.False(t, sendUpdate)
node0ByID, err := adb.GetNodeByID(0)
require.NoError(t, err)
// node0ByID, err := adb.GetNodeByID(0)
// require.NoError(t, err)
users, err := adb.ListUsers()
assert.NoError(t, err)
// users, err := adb.ListUsers()
// assert.NoError(t, err)
nodes, err := adb.ListNodes()
assert.NoError(t, err)
// nodes, err := adb.ListNodes()
// assert.NoError(t, err)
pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
assert.NoError(t, err)
// pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
// assert.NoError(t, err)
// TODO(kradalby): Check state update
err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
require.NoError(t, err)
// // TODO(kradalby): Check state update
// err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
// require.NoError(t, err)
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
require.NoError(t, err)
assert.Len(t, enabledRoutes, len(tt.want))
// enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
// require.NoError(t, err)
// assert.Len(t, enabledRoutes, len(tt.want))
tsaddr.SortPrefixes(enabledRoutes)
// tsaddr.SortPrefixes(enabledRoutes)
if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
})
}
}
// if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
// t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
// }
// })
// }
// }
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
want := []types.NodeID{1, 3}

View file

@ -1,676 +0,0 @@
package db
import (
"errors"
"fmt"
"net/netip"
"sort"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)
var ErrRouteIsNotAvailable = errors.New("route is not available")
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Preload("Node.User").
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", pref.String()).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodeRoutes(rx, node)
})
}
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ?", node.ID).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
return routes, nil
}
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
var route types.Route
err := tx.
Preload("Node").
Preload("Node.User").
First(&route, id).Error
if err != nil {
return nil, err
}
return &route, nil
}
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return nil, err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.IsExitRoute() {
return enableRoutes(
tx,
route.Node,
tsaddr.AllIPv4(),
tsaddr.AllIPv6(),
)
}
return enableRoutes(tx, route.Node, netip.Prefix(route.Prefix))
}
func DisableRoute(tx *gorm.DB,
id uint64,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) {
route, err := GetRoute(tx, id)
if err != nil {
return nil, err
}
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update []types.NodeID
if !route.IsExitRoute() {
route.Enabled = false
err = tx.Save(route).Error
if err != nil {
return nil, err
}
update, err = failoverRouteTx(tx, isLikelyConnected, route)
if err != nil {
return nil, err
}
} else {
routes, err = GetNodeRoutes(tx, node)
if err != nil {
return nil, err
}
for i := range routes {
if routes[i].IsExitRoute() {
routes[i].Enabled = false
routes[i].IsPrimary = false
err = tx.Save(&routes[i]).Error
if err != nil {
return nil, err
}
}
}
}
// If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if update == nil {
update = []types.NodeID{node.ID}
}
return update, nil
}
func (hsdb *HSDatabase) DeleteRoute(
id uint64,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteRoute(tx, id, isLikelyConnected)
})
}
func DeleteRoute(
tx *gorm.DB,
id uint64,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) {
route, err := GetRoute(tx, id)
if err != nil {
return nil, err
}
if route.Node == nil {
// If the route is not assigned to a node, just delete it,
// there are no updates to be sent as no nodes are
// dependent on it
if err := tx.Unscoped().Delete(&route).Error; err != nil {
return nil, err
}
return nil, nil
}
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
// This means that if we delete a route which is an exit route, delete both.
var update []types.NodeID
if route.IsExitRoute() {
routes, err = GetNodeRoutes(tx, node)
if err != nil {
return nil, err
}
var routesToDelete types.Routes
for _, r := range routes {
if r.IsExitRoute() {
routesToDelete = append(routesToDelete, r)
}
}
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
return nil, err
}
} else {
update, err = failoverRouteTx(tx, isLikelyConnected, route)
if err != nil {
return nil, nil
}
if err := tx.Unscoped().Delete(&route).Error; err != nil {
return nil, err
}
}
// If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if routes == nil {
routes, err = GetNodeRoutes(tx, node)
if err != nil {
return nil, err
}
}
node.Routes = routes
if update == nil {
update = []types.NodeID{node.ID}
}
return update, nil
}
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
routes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil, fmt.Errorf("getting node routes: %w", err)
}
var changed []types.NodeID
for i := range routes {
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return nil, fmt.Errorf("deleting route(%d): %w", &routes[i].ID, err)
}
// TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all.
chn, err := failoverRouteTx(tx, isLikelyConnected, &routes[i])
if err != nil {
return changed, fmt.Errorf("failing over route after delete: %w", err)
}
if chn != nil {
changed = append(changed, chn...)
}
}
return changed, nil
}
// isUniquePrefix returns if there is another node providing the same route already.
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64
tx.Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix.String(),
route.NodeID,
true, true).Count(&count)
return count == 0
}
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route
err := tx.
Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
First(&route).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return &route, nil
}
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodePrimaryRoutes(rx, node)
})
}
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary.
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
return SaveNodeRoutes(tx, node)
})
}
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool whether an update should be sent as the
// saved route impacts nodes.
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
sendUpdate := false
currentRoutes := types.Routes{}
err := tx.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil {
return sendUpdate, err
}
advertisedRoutes := map[netip.Prefix]bool{}
for _, prefix := range node.Hostinfo.RoutableIPs {
advertisedRoutes[prefix] = false
}
log.Trace().
Str("node", node.Hostname).
Interface("advertisedRoutes", advertisedRoutes).
Interface("currentRoutes", currentRoutes).
Msg("updating routes")
for pos, route := range currentRoutes {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised {
currentRoutes[pos].Advertised = true
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
// If a route that is newly "saved" is already
// enabled, set sendUpdate to true as it is now
// available.
if route.Enabled {
sendUpdate = true
}
}
advertisedRoutes[netip.Prefix(route.Prefix)] = true
} else if route.Advertised {
currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
}
}
for prefix, exists := range advertisedRoutes {
if !exists {
route := types.Route{
NodeID: node.ID.Uint64(),
Prefix: prefix,
Advertised: true,
Enabled: false,
}
err := tx.Create(&route).Error
if err != nil {
return sendUpdate, err
}
}
}
return sendUpdate, nil
}
// FailoverNodeRoutesIfNecessary takes a node and checks if the node's route
// need to be failed over to another host.
// If needed, the failover will be attempted.
func FailoverNodeRoutesIfNecessary(
tx *gorm.DB,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
node *types.Node,
) (*types.StateUpdate, error) {
nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil, nil
}
changedNodes := make(set.Set[types.NodeID])
nodeRouteLoop:
for _, nodeRoute := range nodeRoutes {
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil {
return nil, fmt.Errorf("getting routes by prefix: %w", err)
}
for _, route := range routes {
if route.IsPrimary {
// if we have a primary route, and the node is connected
// nothing needs to be done.
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
continue nodeRouteLoop
}
// if not, we need to failover the route
failover := failoverRoute(isLikelyConnected, &route, routes)
if failover != nil {
err := failover.save(tx)
if err != nil {
return nil, fmt.Errorf("saving failover routes: %w", err)
}
changedNodes.Add(failover.old.Node.ID)
changedNodes.Add(failover.new.Node.ID)
continue nodeRouteLoop
}
}
}
}
chng := changedNodes.Slice()
sort.SliceStable(chng, func(i, j int) bool {
return chng[i] < chng[j]
})
if len(changedNodes) != 0 {
return ptr.To(types.UpdatePeerChanged(chng...)), nil
}
return nil, nil
}
// failoverRouteTx takes a route that is no longer available,
// this can be either from:
// - being disabled
// - being deleted
// - host going offline
//
// and tries to find a new route to take over its place.
// If the given route was not primary, it returns early.
func failoverRouteTx(
tx *gorm.DB,
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
r *types.Route,
) ([]types.NodeID, error) {
if r == nil {
return nil, nil
}
// This route is not a primary route, and it is not
// being served to nodes.
if !r.IsPrimary {
return nil, nil
}
// We do not have to failover exit nodes
if r.IsExitRoute() {
return nil, nil
}
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
if err != nil {
return nil, fmt.Errorf("getting routes by prefix: %w", err)
}
fo := failoverRoute(isLikelyConnected, r, routes)
if fo == nil {
return nil, nil
}
err = fo.save(tx)
if err != nil {
return nil, fmt.Errorf("saving failover route: %w", err)
}
log.Trace().
Str("hostname", fo.new.Node.Hostname).
Msgf("set primary to new route, was: id(%d), host(%s), now: id(%d), host(%s)", fo.old.ID, fo.old.Node.Hostname, fo.new.ID, fo.new.Node.Hostname)
// Return a list of the machinekeys of the changed nodes.
return []types.NodeID{fo.old.Node.ID, fo.new.Node.ID}, nil
}
type failover struct {
old *types.Route
new *types.Route
}
func (f *failover) save(tx *gorm.DB) error {
err := tx.Save(f.old).Error
if err != nil {
return fmt.Errorf("saving old primary: %w", err)
}
err = tx.Save(f.new).Error
if err != nil {
return fmt.Errorf("saving new primary: %w", err)
}
return nil
}
func failoverRoute(
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
routeToReplace *types.Route,
altRoutes types.Routes,
) *failover {
if routeToReplace == nil {
return nil
}
// This route is not a primary route, and it is not
// being served to nodes.
if !routeToReplace.IsPrimary {
return nil
}
// We do not have to failover exit nodes
if routeToReplace.IsExitRoute() {
return nil
}
var newPrimary *types.Route
// Find a new suitable route
for idx, route := range altRoutes {
if routeToReplace.ID == route.ID {
continue
}
if !route.Enabled {
continue
}
if isLikelyConnected != nil {
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
newPrimary = &altRoutes[idx]
break
}
}
}
// If a new route was not found/available,
// return without an error.
// We do not want to update the database as
// the one currently marked as primary is the
// best we got.
if newPrimary == nil {
return nil
}
routeToReplace.IsPrimary = false
newPrimary.IsPrimary = true
return &failover{
old: routeToReplace,
new: newPrimary,
}
}
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
polMan policy.PolicyManager,
node *types.Node,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
return EnableAutoApprovedRoutes(tx, polMan, node)
})
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
polMan policy.PolicyManager,
node *types.Node,
) error {
if node.IPv4 == nil && node.IPv6 == nil {
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
routes, err := GetNodeAdvertisedRoutes(tx, node)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("getting advertised routes for node(%s %d): %w", node.Hostname, node.ID, err)
}
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
var approvedRoutes types.Routes
for _, advertisedRoute := range routes {
if advertisedRoute.Enabled {
continue
}
routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
log.Trace().
Str("node", node.Hostname).
Uint("user.id", node.User.ID).
Strs("routeApprovers", routeApprovers).
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
Msg("looking up route for autoapproving")
for _, approvedAlias := range routeApprovers {
if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := polMan.ExpandAlias(approvedAlias)
if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if approvedIps.Contains(*node.IPv4) {
approvedRoutes = append(approvedRoutes, advertisedRoute)
}
}
}
}
for _, approvedRoute := range approvedRoutes {
_, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil {
return fmt.Errorf("enabling approved route(%d): %w", approvedRoute.ID, err)
}
}
return nil
}

File diff suppressed because it is too large Load diff

View file

@ -100,6 +100,11 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.WriteHeader(http.StatusOK)
w.Write(registrationsJSON)
}))
debug.Handle("routes", "Routes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.primaryRoutes.String()))
}))
err := statsviz.Register(debugMux)
if err == nil {

View file

@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"slices"
"sort"
"strings"
"time"
@ -18,6 +20,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -326,6 +329,51 @@ func (api headscaleV1APIServer) SetTags(
return &v1.SetTagsResponse{Node: node.Proto()}, nil
}
func (api headscaleV1APIServer) SetApprovedRoutes(
ctx context.Context,
request *v1.SetApprovedRoutesRequest,
) (*v1.SetApprovedRoutesResponse, error) {
var routes []netip.Prefix
for _, route := range request.GetRoutes() {
prefix, err := netip.ParsePrefix(route)
if err != nil {
return nil, fmt.Errorf("parsing route: %w", err)
}
// If the prefix is an exit route, add both. The client expect both
// to annotate the node as an exit node.
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
} else {
routes = append(routes, prefix)
}
}
slices.SortFunc(routes, util.ComparePrefix)
slices.Compact(routes)
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetApprovedRoutes(tx, types.NodeID(request.GetNodeId()), routes)
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}
return &v1.SetApprovedRoutesResponse{Node: node.Proto()}, nil
}
func validateTag(tag string) error {
if strings.Index(tag, "tag:") != 0 {
return errors.New("tag must start with the string 'tag:'")
@ -348,10 +396,7 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err
}
changedNodes, err := api.h.db.DeleteNode(
node,
api.h.nodeNotifier.LikelyConnectedMap(),
)
err = api.h.db.DeleteNode(node)
if err != nil {
return nil, err
}
@ -359,10 +404,6 @@ func (api headscaleV1APIServer) DeleteNode(
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
if changedNodes != nil {
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}
return &v1.DeleteNodeResponse{}, nil
}
@ -533,100 +574,6 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
return &v1.BackfillNodeIPsResponse{Changes: changes}, nil
}
func (api headscaleV1APIServer) GetRoutes(
ctx context.Context,
request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) {
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
return db.GetRoutes(rx)
})
if err != nil {
return nil, err
}
return &v1.GetRoutesResponse{
Routes: types.Routes(routes).Proto(),
}, nil
}
func (api headscaleV1APIServer) EnableRoute(
ctx context.Context,
request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) {
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnableRoute(tx, request.GetRouteId())
})
if err != nil {
return nil, err
}
if update != nil {
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
api.h.nodeNotifier.NotifyAll(
ctx, *update)
}
return &v1.EnableRouteResponse{}, nil
}
func (api headscaleV1APIServer) DisableRoute(
ctx context.Context,
request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) {
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap())
})
if err != nil {
return nil, err
}
if update != nil {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}
return &v1.DisableRouteResponse{}, nil
}
func (api headscaleV1APIServer) GetNodeRoutes(
ctx context.Context,
request *v1.GetNodeRoutesRequest,
) (*v1.GetNodeRoutesResponse, error) {
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
routes, err := api.h.db.GetNodeRoutes(node)
if err != nil {
return nil, err
}
return &v1.GetNodeRoutesResponse{
Routes: types.Routes(routes).Proto(),
}, nil
}
func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context,
request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) {
isConnected := api.h.nodeNotifier.LikelyConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil {
return nil, err
}
if update != nil {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}
return &v1.DeleteRouteResponse{}, nil
}
func (api headscaleV1APIServer) CreateApiKey(
ctx context.Context,
request *v1.CreateApiKeyRequest,

View file

@ -18,6 +18,7 @@ import (
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
@ -56,6 +57,7 @@ type Mapper struct {
derpMap *tailcfg.DERPMap
notif *notifier.Notifier
polMan policy.PolicyManager
primary *routes.PrimaryRoutes
uid string
created time.Time
@ -73,6 +75,7 @@ func NewMapper(
derpMap *tailcfg.DERPMap,
notif *notifier.Notifier,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
@ -82,6 +85,7 @@ func NewMapper(
derpMap: derpMap,
notif: notif,
polMan: polMan,
primary: primary,
uid: uid,
created: time.Now(),
@ -97,15 +101,22 @@ func generateUserProfiles(
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[uint]types.User)
userMap[node.User.ID] = node.User
userMap := make(map[uint]*types.User)
ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User
ids = append(ids, node.User.ID)
for _, peer := range peers {
userMap[peer.User.ID] = peer.User // not worth checking if already is there
userMap[peer.User.ID] = &peer.User
ids = append(ids, peer.User.ID)
}
slices.Sort(ids)
slices.Compact(ids)
var profiles []tailcfg.UserProfile
for _, user := range userMap {
profiles = append(profiles, user.TailscaleUserProfile())
for _, id := range ids {
if userMap[id] != nil {
profiles = append(profiles, userMap[id].TailscaleUserProfile())
}
}
return profiles
@ -166,6 +177,7 @@ func (m *Mapper) fullMapResponse(
resp,
true, // full change
m.polMan,
m.primary,
node,
capVer,
peers,
@ -271,6 +283,7 @@ func (m *Mapper) PeerChangedResponse(
&resp,
false, // partial change
m.polMan,
m.primary,
node,
mapRequest.Version,
changedNodes,
@ -299,7 +312,7 @@ func (m *Mapper) PeerChangedResponse(
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.primary, m.cfg)
if err != nil {
return nil, err
}
@ -446,7 +459,7 @@ func (m *Mapper) baseWithConfigMapResponse(
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
tailnode, err := tailNode(node, capVer, m.polMan, m.primary, m.cfg)
if err != nil {
return nil, err
}
@ -500,6 +513,7 @@ func appendPeerChanges(
fullChange bool,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
node *types.Node,
capVer tailcfg.CapabilityVersion,
changed types.Nodes,
@ -522,7 +536,7 @@ func appendPeerChanges(
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
tailPeers, err := tailNodes(changed, capVer, polMan, primary, cfg)
if err != nil {
return err
}

View file

@ -6,12 +6,11 @@ import (
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -24,51 +23,6 @@ var iap = func(ipStr string) *netip.Addr {
return &ip
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: userid,
User: types.User{
Model: gorm.Model{
ID: userid,
},
Name: username,
},
}
}
nodeInShared1 := mach("test_get_shared_nodes_1", "user1", 1)
nodeInShared2 := mach("test_get_shared_nodes_2", "user2", 2)
nodeInShared3 := mach("test_get_shared_nodes_3", "user3", 3)
node2InShared1 := mach("test_get_shared_nodes_4", "user1", 1)
userProfiles := generateUserProfiles(
nodeInShared1,
types.Nodes{
nodeInShared2, nodeInShared3, node2InShared1,
},
)
c.Assert(len(userProfiles), check.Equals, 3)
users := []string{
"user1", "user2", "user3",
}
for _, user := range users {
found := false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == user {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
}
func TestDNSConfigMapResponse(t *testing.T) {
tests := []struct {
magicDNS bool
@ -159,11 +113,11 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2"}
mini := &types.Node{
ID: 0,
ID: 1,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
@ -182,35 +136,22 @@ func Test_fullMapResponse(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: tsaddr.AllIPv4(),
Advertised: true,
Enabled: true,
IsPrimary: false,
},
{
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
Advertised: true,
Enabled: true,
IsPrimary: true,
},
{
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
Advertised: true,
Enabled: false,
IsPrimary: true,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
},
CreatedAt: created,
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
}
tailMini := &tailcfg.Node{
ID: 0,
StableID: "0",
ID: 1,
StableID: "1",
Name: "mini",
User: 0,
User: tailcfg.UserID(user1.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
@ -227,12 +168,17 @@ func Test_fullMapResponse(t *testing.T) {
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
}),
Created: created,
Tags: []string{},
PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
LastSeen: &lastSeen,
MachineAuthorized: true,
@ -244,7 +190,7 @@ func Test_fullMapResponse(t *testing.T) {
}
peer1 := &types.Node{
ID: 1,
ID: 2,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
@ -257,20 +203,20 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"),
Hostname: "peer1",
GivenName: "peer1",
UserID: user1.ID,
User: user1,
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}
tailPeer1 := &tailcfg.Node{
ID: 1,
StableID: "1",
ID: 2,
StableID: "2",
Name: "peer1",
User: tailcfg.UserID(user2.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
@ -288,7 +234,6 @@ func Test_fullMapResponse(t *testing.T) {
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
Tags: []string{},
PrimaryRoutes: []netip.Prefix{},
LastSeen: &lastSeen,
MachineAuthorized: true,
@ -299,30 +244,6 @@ func Test_fullMapResponse(t *testing.T) {
},
}
peer2 := &types.Node{
ID: 2,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.3"),
Hostname: "peer2",
GivenName: "peer2",
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}
tests := []struct {
name string
pol *policy.ACLPolicy
@ -364,7 +285,7 @@ func Test_fullMapResponse(t *testing.T) {
Domain: "",
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
@ -398,9 +319,12 @@ func Test_fullMapResponse(t *testing.T) {
Domain: "",
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
@ -410,6 +334,9 @@ func Test_fullMapResponse(t *testing.T) {
{
name: "with-pol-map-response",
pol: &policy.ACLPolicy{
Hosts: policy.Hosts{
"mini": netip.MustParsePrefix("100.64.0.1/32"),
},
ACLs: []policy.ACL{
{
Action: "accept",
@ -421,7 +348,6 @@ func Test_fullMapResponse(t *testing.T) {
node: mini,
peers: types.Nodes{
peer1,
peer2,
},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
@ -449,7 +375,8 @@ func Test_fullMapResponse(t *testing.T) {
},
},
UserProfiles: []tailcfg.UserProfile{
{LoginName: "mini", DisplayName: "mini"},
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
@ -464,6 +391,12 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
primary := routes.New()
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
for _, peer := range tt.peers {
primary.SetRoutes(peer.ID, peer.SubnetRoutes()...)
}
mappy := NewMapper(
nil,
@ -471,6 +404,7 @@ func Test_fullMapResponse(t *testing.T) {
tt.derpMap,
nil,
polMan,
primary,
)
got, err := mappy.fullMapResponse(
@ -485,8 +419,6 @@ func Test_fullMapResponse(t *testing.T) {
return
}
spew.Dump(got)
if diff := cmp.Diff(
tt.want,
got,

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo"
"tailscale.com/tailcfg"
@ -15,6 +16,7 @@ func tailNodes(
nodes types.Nodes,
capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
cfg *types.Config,
) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes))
@ -24,6 +26,7 @@ func tailNodes(
node,
capVer,
polMan,
primary,
cfg,
)
if err != nil {
@ -41,6 +44,7 @@ func tailNode(
node *types.Node,
capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
cfg *types.Config,
) (*tailcfg.Node, error) {
addrs := node.Prefixes()
@ -49,17 +53,8 @@ func tailNode(
[]netip.Prefix{},
addrs...) // we append the node own IP, as it is required by the clients
primaryPrefixes := []netip.Prefix{}
for _, route := range node.Routes {
if route.Enabled {
if route.IsPrimary {
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix))
primaryPrefixes = append(primaryPrefixes, netip.Prefix(route.Prefix))
} else if route.IsExitRoute() {
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix))
}
}
for _, route := range node.SubnetRoutes() {
allowedIPs = append(allowedIPs, netip.Prefix(route))
}
var derp int
@ -103,6 +98,7 @@ func tailNode(
Machine: node.MachineKey,
DiscoKey: node.DiscoKey,
Addresses: addrs,
PrimaryRoutes: primary.PrimaryRoutes(node.ID),
AllowedIPs: allowedIPs,
Endpoints: node.Endpoints,
HomeDERP: derp,
@ -114,8 +110,6 @@ func tailNode(
Tags: tags,
PrimaryRoutes: primaryPrefixes,
MachineAuthorized: !node.IsExpired(),
Expired: node.IsExpired(),
}

View file

@ -9,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -72,7 +73,6 @@ func TestTailNode(t *testing.T) {
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
PrimaryRoutes: []netip.Prefix{},
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
@ -107,28 +107,15 @@ func TestTailNode(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: tsaddr.AllIPv4(),
Advertised: true,
Enabled: true,
IsPrimary: false,
},
{
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
Advertised: true,
Enabled: true,
IsPrimary: true,
},
{
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
Advertised: true,
Enabled: false,
IsPrimary: true,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
},
CreatedAt: created,
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
@ -159,8 +146,14 @@ func TestTailNode(t *testing.T) {
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
Hostinfo: hiview(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
}),
Created: created,
Tags: []string{},
@ -187,15 +180,22 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
primary := routes.New()
cfg := &types.Config{
BaseDomain: tt.baseDomain,
TailcfgDNSConfig: tt.dnsConfig,
RandomizeClientPort: false,
}
_ = primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
// This is a hack to avoid having a second node to test the primary route.
// This should be baked into the test case proper if it is extended in the future.
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
got, err := tailNode(
tt.node,
0,
polMan,
primary,
cfg,
)
@ -249,6 +249,7 @@ func TestNodeExpiry(t *testing.T) {
node,
0,
&policy.PolicyManagerV1{},
nil,
&types.Config{},
)
if err != nil {

View file

@ -243,6 +243,7 @@ func (pol *ACLPolicy) CompileFilterRules(
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
// TODO(kradalby): Make this nil and not alloc unless needed
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
@ -264,13 +265,11 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if node.Hostinfo != nil {
if len(node.Hostinfo.RoutableIPs) > 0 {
for _, routableIP := range node.Hostinfo.RoutableIPs {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
if len(node.SubnetRoutes()) > 0 {
for _, routableIP := range node.SubnetRoutes() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}

View file

@ -2165,6 +2165,9 @@ func TestReduceFilterRules(t *testing.T) {
netip.MustParsePrefix("10.33.0.0/16"),
},
},
ApprovedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
peers: types.Nodes{
&types.Node{
@ -2292,6 +2295,7 @@ func TestReduceFilterRules(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
ApprovedRoutes: tsaddr.ExitRoutes(),
},
peers: types.Nodes{
&types.Node{
@ -2398,6 +2402,7 @@ func TestReduceFilterRules(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
ApprovedRoutes: tsaddr.ExitRoutes(),
},
peers: types.Nodes{
&types.Node{
@ -2513,6 +2518,10 @@ func TestReduceFilterRules(t *testing.T) {
netip.MustParsePrefix("16.0.0.0/16"),
},
},
ApprovedRoutes: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/16"),
netip.MustParsePrefix("16.0.0.0/16"),
},
},
peers: types.Nodes{
&types.Node{
@ -2603,6 +2612,10 @@ func TestReduceFilterRules(t *testing.T) {
netip.MustParsePrefix("16.0.0.0/8"),
},
},
ApprovedRoutes: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/8"),
netip.MustParsePrefix("16.0.0.0/8"),
},
},
peers: types.Nodes{
&types.Node{
@ -2683,7 +2696,8 @@ func TestReduceFilterRules(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
@ -3475,14 +3489,10 @@ func Test_getFilteredByACLPeers(t *testing.T) {
IPv4: iap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
Routes: types.Routes{
types.Route{
NodeID: 2,
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
IsPrimary: true,
Enabled: true,
},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
},
rules: []tailcfg.FilterRule{
@ -3508,14 +3518,10 @@ func Test_getFilteredByACLPeers(t *testing.T) {
IPv4: iap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
Routes: types.Routes{
types.Route{
NodeID: 2,
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
IsPrimary: true,
Enabled: true,
},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
},
},

View file

@ -9,8 +9,8 @@ import (
)
type Match struct {
Srcs *netipx.IPSet
Dests *netipx.IPSet
srcs *netipx.IPSet
dests *netipx.IPSet
}
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
@ -42,16 +42,16 @@ func MatchFromStrings(sources, destinations []string) Match {
destsSet, _ := dests.IPSet()
match := Match{
Srcs: srcsSet,
Dests: destsSet,
srcs: srcsSet,
dests: destsSet,
}
return match
}
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
func (m *Match) SrcsContainsIPs(ips ...netip.Addr) bool {
for _, ip := range ips {
if m.Srcs.Contains(ip) {
if m.srcs.Contains(ip) {
return true
}
}
@ -59,9 +59,29 @@ func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
return false
}
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
func (m *Match) DestsContainsIP(ips ...netip.Addr) bool {
for _, ip := range ips {
if m.Dests.Contains(ip) {
if m.dests.Contains(ip) {
return true
}
}
return false
}
func (m *Match) SrcsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
for _, prefix := range prefixes {
if m.srcs.ContainsPrefix(prefix) {
return true
}
}
return false
}
func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
for _, prefix := range prefixes {
if m.dests.ContainsPrefix(prefix) {
return true
}
}

View file

@ -23,6 +23,9 @@ type PolicyManager interface {
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
@ -185,3 +188,32 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
}
return ips, nil
}
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
if pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
for _, approvedAlias := range approvers {
if approvedAlias == node.User.Username() {
return true
} else {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
if err != nil {
return false
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if ips.Contains(*node.IPv4) {
return true
}
}
}
return false
}

View file

@ -7,16 +7,15 @@ import (
"net/http"
"net/netip"
"slices"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@ -205,7 +204,15 @@ func (m *mapSession) serveLongPoll() {
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
// Failover the node's routes if any.
m.h.updateNodeOnlineStatus(false, m.node)
m.pollFailoverRoutes("node closing connection", m.node)
// When a node disconnects, and it causes the primary route map to change,
// send a full update to all nodes.
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
if m.h.primaryRoutes.SetRoutes(m.node.ID) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
m.afterServeLongPoll()
@ -216,7 +223,10 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
m.pollFailoverRoutes("node connected", m.node)
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
@ -383,22 +393,6 @@ func (m *mapSession) serveLongPoll() {
}
}
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
})
if err != nil {
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
return
}
if update != nil && !update.Empty() {
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID)
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@ -414,15 +408,6 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
// lastSeen is only relevant if the node is disconnected.
node.LastSeen = &now
change.LastSeen = &now
err := h.db.Write(func(tx *gorm.DB) error {
return db.SetLastSeen(tx, node.ID, *node.LastSeen)
})
if err != nil {
log.Error().Err(err).Msg("Cannot update node LastSeen")
return
}
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
@ -471,36 +456,47 @@ func (m *mapSession) handleEndpointUpdate() {
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
var err error
_, err = m.h.db.SaveNodeRoutes(m.node)
if err != nil {
m.errf(err, "Error processing node routes")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
// TODO(kradalby): Only update the node that has actually changed
// TODO(kradalby): I am not sure if we need this?
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
if m.h.polMan != nil {
// update routes with peer information
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
if err != nil {
m.errf(err, "Error running auto approved routes")
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
// Take all the routes presented to us by the node and check
// if any of them should be auto approved by the policy.
// If any of them are, add them to the approved routes of the node.
// Keep all the old entries and compact the list to remove duplicates.
var newApproved []netip.Prefix
for _, route := range m.node.Hostinfo.RoutableIPs {
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, m.node.ApprovedRoutes...)
slices.SortFunc(newApproved, util.ComparePrefix)
slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
m.node.ApprovedRoutes = newApproved
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(m.node.ID), m.node.ID)
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
}
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
// Check if there has been a change to Hostname and update them

186
hscontrol/routes/primary.go Normal file
View file

@ -0,0 +1,186 @@
package routes
import (
"fmt"
"log"
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
xmaps "golang.org/x/exp/maps"
"tailscale.com/util/set"
)
type PrimaryRoutes struct {
mu sync.Mutex
// routes is a map of prefixes that are adverties and approved and available
// in the global headscale state.
routes map[types.NodeID]set.Set[netip.Prefix]
// primaries is a map of prefixes to the node that is the primary for that prefix.
primaries map[netip.Prefix]types.NodeID
isPrimary map[types.NodeID]bool
}
func New() *PrimaryRoutes {
return &PrimaryRoutes{
routes: make(map[types.NodeID]set.Set[netip.Prefix]),
primaries: make(map[netip.Prefix]types.NodeID),
isPrimary: make(map[types.NodeID]bool),
}
}
// updatePrimaryLocked recalculates the primary routes and updates the internal state.
// It returns true if the primary routes have changed.
// It is assumed that the caller holds the lock.
// The algorthm is as follows:
// 1. Reset the primaries map.
// 2. Iterate over the routes and count the number of times a prefix is advertised.
// 3. If a prefix is advertised by at least two nodes, it is a primary route.
// 4. If the primary routes have changed, update the internal state and return true.
// 5. Otherwise, return false.
func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
// reset the primaries map, as we are going to recalculate it.
allPrimaries := make(map[netip.Prefix][]types.NodeID)
pr.isPrimary = make(map[types.NodeID]bool)
changed := false
// sort the node ids so we can iterate over them in a deterministic order.
// this is important so the same node is chosen two times in a row
// as the primary route.
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
// Create a map of prefixes to nodes that serve them so we
// can determine the primary route for each prefix.
for _, id := range ids {
routes := pr.routes[id]
for route := range routes {
if _, ok := allPrimaries[route]; !ok {
allPrimaries[route] = []types.NodeID{id}
} else {
allPrimaries[route] = append(allPrimaries[route], id)
}
}
}
// Go through all prefixes and determine the primary route for each.
// If the number of routes is below the minimum, remove the primary.
// If the current primary is still available, continue.
// If the current primary is not available, select a new one.
for prefix, nodes := range allPrimaries {
if node, ok := pr.primaries[prefix]; ok {
if len(nodes) < 2 {
delete(pr.primaries, prefix)
changed = true
continue
}
// If the current primary is still available, continue.
if slices.Contains(nodes, node) {
continue
}
}
if len(nodes) >= 2 {
pr.primaries[prefix] = nodes[0]
changed = true
}
}
// Clean up any remaining primaries that are no longer valid.
for prefix := range pr.primaries {
if _, ok := allPrimaries[prefix]; !ok {
delete(pr.primaries, prefix)
changed = true
}
}
// Populate the quick lookup index for primary routes
for _, nodeID := range pr.primaries {
pr.isPrimary[nodeID] = true
}
return changed
}
func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool {
pr.mu.Lock()
defer pr.mu.Unlock()
// If no routes are being set, remove the node from the routes map.
if len(prefix) == 0 {
log.Printf("Removing node %d from routes", node)
if _, ok := pr.routes[node]; ok {
delete(pr.routes, node)
return pr.updatePrimaryLocked()
}
return false
}
if _, ok := pr.routes[node]; !ok {
pr.routes[node] = make(set.Set[netip.Prefix], len(prefix))
}
for _, p := range prefix {
pr.routes[node].Add(p)
}
return pr.updatePrimaryLocked()
}
func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
if pr == nil {
return nil
}
pr.mu.Lock()
defer pr.mu.Unlock()
// Short circuit if the node is not a primary for any route.
if _, ok := pr.isPrimary[id]; !ok {
return nil
}
var routes []netip.Prefix
for prefix, node := range pr.primaries {
if node == id {
routes = append(routes, prefix)
}
}
return routes
}
func (pr *PrimaryRoutes) String() string {
pr.mu.Lock()
defer pr.mu.Unlock()
return pr.stringLocked()
}
func (pr *PrimaryRoutes) stringLocked() string {
var sb strings.Builder
fmt.Fprintln(&sb, "Available routes:")
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
for _, id := range ids {
prefixes := pr.routes[id]
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
}
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
for route, nodeID := range pr.primaries {
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
}
return sb.String()
}

View file

@ -0,0 +1,316 @@
package routes
import (
"net/netip"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
// mp is a helper function that wraps netip.MustParsePrefix.
func mp(prefix string) netip.Prefix {
return netip.MustParsePrefix(prefix)
}
func TestPrimaryRoutes(t *testing.T) {
tests := []struct {
name string
operations func(pr *PrimaryRoutes) bool
nodeID types.NodeID
expectedRoutes []netip.Prefix
expectedChange bool
}{
{
name: "single-node-registers-single-route",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "multiple-nodes-register-different-routes",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.2.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "multiple-nodes-register-overlapping-routes",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
return pr.SetRoutes(2, mp("192.168.1.0/24")) // true
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
},
{
name: "node-deregisters-a-route",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24"))
return pr.SetRoutes(1) // Deregister by setting no routes
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "node-deregisters-one-of-multiple-routes",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24"))
return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "node-registers-and-deregisters-routes-in-sequence",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24"))
pr.SetRoutes(2, mp("192.168.2.0/24"))
pr.SetRoutes(1) // Deregister by setting no routes
return pr.SetRoutes(1, mp("192.168.3.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "no-change-in-primary-routes",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1, mp("192.168.1.0/24"))
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "multiple-nodes-register-same-route",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false,
},
{
name: "register-multiple-routes-shift-primary-check-old-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: true,
},
{
name: "register-multiple-routes-shift-primary-check-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
},
{
name: "register-multiple-routes-shift-primary-check-non-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
},
nodeID: 3,
expectedRoutes: nil,
expectedChange: true,
},
{
name: "primary-route-map-is-cleared-up-no-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(2) // true, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedChange: true,
},
{
name: "primary-route-map-is-cleared-up-all-no-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
pr.SetRoutes(2) // true, no primary
return pr.SetRoutes(3) // false, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "primary-route-map-is-cleared-up",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(2) // true, no primary
},
nodeID: 2,
expectedRoutes: nil,
expectedChange: true,
},
{
name: "primary-route-no-flake",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false,
},
{
name: "primary-route-no-flake-check-old-primary",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "primary-route-no-flake-full-integration",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
pr.SetRoutes(1) // true, 2 primary
pr.SetRoutes(2) // true, no primary
pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary
},
nodeID: 2,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: false,
},
{
name: "multiple-nodes-register-same-route-and-exit",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24"))
return pr.SetRoutes(2, mp("192.168.1.0/24"))
},
nodeID: 1,
expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")},
expectedChange: true,
},
{
name: "deregister-non-existent-route",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1) // Deregister by setting no routes
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "register-empty-prefix-list",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1)
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "deregister-empty-prefix-list",
operations: func(pr *PrimaryRoutes) bool {
return pr.SetRoutes(1)
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "concurrent-access",
operations: func(pr *PrimaryRoutes) bool {
var wg sync.WaitGroup
wg.Add(2)
var change1, change2 bool
go func() {
defer wg.Done()
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
}()
go func() {
defer wg.Done()
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
}()
wg.Wait()
return change1 || change2
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
{
name: "no-routes-registered",
operations: func(pr *PrimaryRoutes) bool {
// No operations
return false
},
nodeID: 1,
expectedRoutes: nil,
expectedChange: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pr := New()
change := tt.operations(pr)
if change != tt.expectedChange {
t.Errorf("change = %v, want %v", change, tt.expectedChange)
}
routes := pr.PrimaryRoutes(tt.nodeID)
if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" {
t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff)
}
})
}
}

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"time"
@ -25,8 +26,11 @@ var (
)
type NodeID uint64
type NodeIDs []NodeID
// type NodeConnectedMap *xsync.MapOf[NodeID, bool]
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (id NodeID) StableID() tailcfg.StableNodeID {
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
@ -84,10 +88,21 @@ type Node struct {
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
AuthKey *PreAuthKey
LastSeen *time.Time
Expiry *time.Time
Expiry *time.Time
Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
// LastSeen is when the node was last in contact with
// headscale. It is best effort and not persisted.
LastSeen *time.Time `gorm:"-"`
// DEPRECATED: Use the ApprovedRoutes field instead.
// TODO(kradalby): remove when ApprovedRoutes is used all over the code.
// Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
// ApprovedRoutes is a list of routes that the node is allowed to announce
// as a subnet router. They are not necessarily the routes that the node
// announces at the moment.
// See [Node.Hostinfo]
ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"`
CreatedAt time.Time
UpdatedAt time.Time
@ -96,9 +111,7 @@ type Node struct {
IsOnline *bool `gorm:"-"`
}
type (
Nodes []*Node
)
type Nodes []*Node
// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node.
func (node *Node) GivenNameHasBeenChanged() bool {
@ -185,23 +198,22 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
// TODO(kradalby): Regenerate this every time the filter change, instead of
// every time we use it.
// Part of #2416
matchers := make([]matcher.Match, len(filter))
for i, rule := range filter {
matchers[i] = matcher.MatchFromFilterRule(rule)
}
for _, route := range node2.Routes {
if route.Enabled {
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
}
}
for _, matcher := range matchers {
if !matcher.SrcsContainsIPs(src) {
if !matcher.SrcsContainsIPs(src...) {
continue
}
if matcher.DestsContainsIP(allowedIPs) {
if matcher.DestsContainsIP(allowedIPs...) {
return true
}
if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) {
return true
}
}
@ -245,11 +257,14 @@ func (node *Node) Proto() *v1.Node {
DiscoKey: node.DiscoKey.String(),
// TODO(kradalby): replace list with v4, v6 field?
IpAddresses: node.IPsAsString(),
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
IpAddresses: node.IPsAsString(),
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
ApprovedRoutes: util.PrefixesToString(node.ApprovedRoutes),
AvailableRoutes: util.PrefixesToString(node.AnnouncedRoutes()),
SubnetRoutes: util.PrefixesToString(node.SubnetRoutes()),
RegisterMethod: node.RegisterMethodToV1Enum(),
@ -297,6 +312,29 @@ func (node *Node) GetFQDN(baseDomain string) (string, error) {
return hostname, nil
}
// AnnouncedRoutes returns the list of routes that the node announces.
// It should be used instead of checking Hostinfo.RoutableIPs directly.
func (node *Node) AnnouncedRoutes() []netip.Prefix {
if node.Hostinfo == nil {
return nil
}
return node.Hostinfo.RoutableIPs
}
// SubnetRoutes returns the list of routes that the node announces and are approved.
func (node *Node) SubnetRoutes() []netip.Prefix {
var routes []netip.Prefix
for _, route := range node.AnnouncedRoutes() {
if slices.Contains(node.ApprovedRoutes, route) {
routes = append(routes, route)
}
}
return routes
}
// func (node *Node) String() string {
// return node.Hostname
// }

View file

@ -1,102 +1,31 @@
package types
import (
"fmt"
"net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
)
// Deprecated: Approval of routes is denormalised onto the relevant node.
// Struct is kept for GORM migrations only.
type Route struct {
gorm.Model
NodeID uint64 `gorm:"not null"`
Node *Node
// TODO(kradalby): change this custom type to netip.Prefix
Prefix netip.Prefix `gorm:"serializer:text"`
// Advertised is now only stored as part of [Node.Hostinfo].
Advertised bool
Enabled bool
IsPrimary bool
// Enabled is stored directly on the node as ApprovedRoutes.
Enabled bool
// IsPrimary is only determined in memory as it is only relevant
// when the server is up.
IsPrimary bool
}
// Deprecated: Approval of routes is denormalised onto the relevant node.
type Routes []Route
func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String())
}
func (r *Route) IsExitRoute() bool {
return tsaddr.IsExitRoute(r.Prefix)
}
func (r *Route) IsAnnouncable() bool {
return r.Advertised && r.Enabled
}
func (rs Routes) Prefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs {
prefixes[i] = netip.Prefix(r.Prefix)
}
return prefixes
}
// Primaries returns Primary routes from a list of routes.
func (rs Routes) Primaries() Routes {
res := make(Routes, 0)
for _, route := range rs {
if route.IsPrimary {
res = append(res, route)
}
}
return res
}
func (rs Routes) PrefixMap() map[netip.Prefix][]Route {
res := map[netip.Prefix][]Route{}
for _, route := range rs {
if _, ok := res[route.Prefix]; ok {
res[route.Prefix] = append(res[route.Prefix], route)
} else {
res[route.Prefix] = []Route{route}
}
}
return res
}
func (rs Routes) Proto() []*v1.Route {
protoRoutes := []*v1.Route{}
for _, route := range rs {
protoRoute := v1.Route{
Id: uint64(route.ID),
Prefix: route.Prefix.String(),
Advertised: route.Advertised,
Enabled: route.Enabled,
IsPrimary: route.IsPrimary,
CreatedAt: timestamppb.New(route.CreatedAt),
UpdatedAt: timestamppb.New(route.UpdatedAt),
}
if route.Node != nil {
protoRoute.Node = route.Node.Proto()
}
if route.DeletedAt.Valid {
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
}
protoRoutes = append(protoRoutes, &protoRoute)
}
return protoRoutes
}

View file

@ -1,89 +0,0 @@
package types
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/util"
)
func TestPrefixMap(t *testing.T) {
ipp := func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
tests := []struct {
rs Routes
want map[netip.Prefix][]Route
}{
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
},
},
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
Route{
Prefix: ipp("10.0.1.0/24"),
},
},
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
ipp("10.0.1.0/24"): Routes{
Route{
Prefix: ipp("10.0.1.0/24"),
},
},
},
},
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: true,
},
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: false,
},
},
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: true,
},
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: false,
},
},
},
},
}
for idx, tt := range tests {
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
got := tt.rs.PrefixMap()
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -55,6 +55,13 @@ type User struct {
ProfilePicURL string
}
func (u *User) StringID() string {
if u == nil {
return ""
}
return strconv.FormatUint(uint64(u.ID), 10)
}
// Username is the main way to get the username of a user,
// it will return the email if it exists, the name if it exists,
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
@ -63,7 +70,11 @@ type User struct {
// should be used throughout headscale, in information returned to the
// user and the Policy engine.
func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
return cmp.Or(
u.Email,
u.Name,
u.ProviderIdentifier.String,
u.StringID())
}
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise

View file

@ -1,8 +1,10 @@
package util
import (
"cmp"
"context"
"net"
"net/netip"
)
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
@ -10,3 +12,40 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
return d.DialContext(ctx, "unix", addr)
}
// TODO(kradalby): Remove when in stdlib;
// https://github.com/golang/go/issues/61642
// Compare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// Prefixes sort first by validity (invalid before valid), then
// address family (IPv4 before IPv6), then prefix length, then
// address.
func ComparePrefix(p, p2 netip.Prefix) int {
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
return c
}
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
return c
}
return p.Addr().Compare(p2.Addr())
}
func PrefixesToString(prefixes []netip.Prefix) []string {
ret := make([]string, 0, len(prefixes))
for _, prefix := range prefixes {
ret = append(ret, prefix.String())
}
return ret
}
func MustStringsToPrefixes(strings []string) []netip.Prefix {
ret := make([]netip.Prefix, 0, len(strings))
for _, str := range strings {
prefix := netip.MustParsePrefix(str)
ret = append(ret, prefix)
}
return ret
}

View file

@ -74,3 +74,21 @@ func TailMapResponseToString(resp tailcfg.MapResponse) string {
TailNodesToString(resp.Peers),
)
}
func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
var sb strings.Builder
for index, rule := range rules {
sb.WriteString(fmt.Sprintf(`
{
SrcIPs: %v
DstIPs: %v
}
`, rule.SrcIPs, rule.DstPorts))
if index < len(rules)-1 {
sb.WriteString(", ")
}
}
return fmt.Sprintf("[ %s ](%d)", sb.String(), len(rules))
}