policy: reduce routes sent to peers based on packetfilter (#2561)

* notifier: use convenience funcs

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy: reduce routes based on policy

Fixes #2365

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* hsic: more helper methods

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy: more test cases

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: add route with filter acl integration test

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* integration: correct route reduce test, now failing

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* mapper: compare peer routes against node

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* hs: more output to debug strings

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* types/node: slice.ContainsFunc

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* policy: more reduce route test

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* changelog: add entry for route filter

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-05-04 22:52:47 +03:00 committed by GitHub
parent b9868f6516
commit 45e38cb080
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 903 additions and 47 deletions

View file

@ -21,6 +21,8 @@ type ControlServer interface {
CreateUser(user string) (*v1.User, error)
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
ListNodes(users ...string) ([]*v1.Node, error)
NodesByUser() (map[string][]*v1.Node, error)
NodesByName() (map[string]*v1.Node, error)
ListUsers() ([]*v1.User, error)
MapUsers() (map[string]*v1.User, error)
ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)

View file

@ -819,6 +819,38 @@ func (t *HeadscaleInContainer) ListNodes(
return ret, nil
}
func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
nodes, err := t.ListNodes()
if err != nil {
return nil, err
}
var userMap map[string][]*v1.Node
for _, node := range nodes {
if _, ok := userMap[node.User.Name]; !ok {
mak.Set(&userMap, node.User.Name, []*v1.Node{node})
} else {
userMap[node.User.Name] = append(userMap[node.User.Name], node)
}
}
return userMap, nil
}
func (t *HeadscaleInContainer) NodesByName() (map[string]*v1.Node, error) {
nodes, err := t.ListNodes()
if err != nil {
return nil, err
}
var nameMap map[string]*v1.Node
for _, node := range nodes {
mak.Set(&nameMap, node.GetName(), node)
}
return nameMap, nil
}
// ListUsers returns a list of users from Headscale.
func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) {
command := []string{"headscale", "users", "list", "--output", "json"}
@ -973,7 +1005,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
"headscale", "nodes", "approve-routes",
"--output", "json",
"--identifier", strconv.FormatUint(id, 10),
fmt.Sprintf("--routes=%q", strings.Join(util.PrefixesToString(routes), ",")),
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")),
}
result, _, err := dockertestutil.ExecuteCommand(

View file

@ -1,6 +1,7 @@
package integration
import (
"encoding/json"
"fmt"
"net/netip"
"sort"
@ -9,7 +10,7 @@ import (
"slices"
"github.com/google/go-cmp/cmp"
cmpdiff "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
@ -23,6 +24,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/types/ipproto"
"tailscale.com/types/views"
"tailscale.com/util/must"
"tailscale.com/util/slicesx"
"tailscale.com/wgengine/filter"
)
@ -940,7 +942,7 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
}
@ -990,7 +992,7 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
}
@ -1603,9 +1605,9 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
}
for _, tt := range tests {
for _, dbMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
for _, advertiseDuringUp := range []bool{false, true} {
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, dbMode)
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode)
t.Run(name, func(t *testing.T) {
scenario, err := NewScenario(tt.spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
@ -1616,7 +1618,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
hsic.WithACLPolicy(tt.pol),
hsic.WithPolicyMode(dbMode),
hsic.WithPolicyMode(polMode),
}
tsOpts := []tsic.Option{
@ -2007,7 +2009,7 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
})
if diff := cmp.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
t.Fatalf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff)
}
}
@ -2018,3 +2020,193 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
require.Lenf(t, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}
// TestSubnetRouteACLFiltering tests that a node can only access subnet routes
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Use router and node users for better clarity
routerUser := "router"
nodeUser := "node"
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{routerUser, nodeUser},
Networks: map[string][]string{
"usernet1": {routerUser, nodeUser},
},
ExtraService: map[string][]extraServiceFunc{
"usernet1": {Webservice},
},
// We build the head image with curl and traceroute, so only use
// that for this test.
Versions: []string{"head"},
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
aclPolicyStr := fmt.Sprintf(`{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
},
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"router:8000"
]
},
{
"action": "accept",
"src": [
"node"
],
"dst": []
}
]
}`)
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
services, err := scenario.Services("usernet1")
require.NoError(t, err)
require.Len(t, services, 1)
usernet1, err := scenario.Network("usernet1")
require.NoError(t, err)
web := services[0]
webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1))
weburl := fmt.Sprintf("http://%s/etc/hostname", webip)
t.Logf("webservice: %s, %s", webip.String(), weburl)
// Create ACL policy
aclPolicy := &policyv1.ACLPolicy{}
err = json.Unmarshal([]byte(aclPolicyStr), aclPolicy)
require.NoError(t, err)
err = scenario.CreateHeadscaleEnv([]tsic.Option{
tsic.WithAcceptRoutes(),
}, hsic.WithTestName("routeaclfilter"),
hsic.WithACLPolicy(aclPolicy),
hsic.WithPolicyMode(types.PolicyModeDB),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
// Sort clients by ID for consistent order
slices.SortFunc(allClients, func(a, b TailscaleClient) int {
return b.MustIPv4().Compare(a.MustIPv4())
})
// Get the router and node clients
routerClient := allClients[0]
nodeClient := allClients[1]
aclPolicy.Hosts = policyv1.Hosts{
routerUser: must.Get(routerClient.MustIPv4().Prefix(32)),
nodeUser: must.Get(nodeClient.MustIPv4().Prefix(32)),
}
aclPolicy.ACLs[1].Destinations = []string{
route.String() + ":*",
}
require.NoError(t, headscale.SetPolicy(aclPolicy))
// Set up the subnet routes for the router
routes := []string{
route.String(), // This should be accessible by the client
"10.10.11.0/24", // These should NOT be accessible
"10.10.12.0/24",
}
routeArg := "--advertise-routes=" + routes[0] + "," + routes[1] + "," + routes[2]
command := []string{
"tailscale",
"set",
routeArg,
}
_, _, err = routerClient.Execute(command)
require.NoErrorf(t, err, "failed to advertise routes: %s", err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
// List nodes and verify the router has 3 available routes
nodes, err := headscale.NodesByUser()
require.NoError(t, err)
require.Len(t, nodes, 2)
// Find the router node
routerNode := nodes[routerUser][0]
nodeNode := nodes[nodeUser][0]
require.NotNil(t, routerNode, "Router node not found")
require.NotNil(t, nodeNode, "Client node not found")
// Check that the router has 3 routes available but not approved yet
requireNodeRouteCount(t, routerNode, 3, 0, 0)
requireNodeRouteCount(t, nodeNode, 0, 0, 0)
// Approve all routes for the router
_, err = headscale.ApproveRoutes(
routerNode.GetId(),
util.MustStringsToPrefixes(routerNode.GetAvailableRoutes()),
)
require.NoError(t, err)
// Give some time for the routes to propagate
time.Sleep(5 * time.Second)
// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
require.NoError(t, err)
require.Len(t, nodes, 2)
// Find the router node
routerNode = nodes[routerUser][0]
// Check that the router has 3 routes now approved and available
requireNodeRouteCount(t, routerNode, 3, 3, 3)
// Now check the client node status
nodeStatus, err := nodeClient.Status()
require.NoError(t, err)
routerStatus, err := routerClient.Status()
require.NoError(t, err)
// Check that the node can see the subnet routes from the router
routerPeerStatus := nodeStatus.Peer[routerStatus.Self.PublicKey]
// The node should only have 1 subnet route
requirePeerSubnetRoutes(t, routerPeerStatus, []netip.Prefix{*route})
result, err := nodeClient.Curl(weburl)
require.NoError(t, err)
assert.Len(t, result, 13)
tr, err := nodeClient.Traceroute(webip)
require.NoError(t, err)
assertTracerouteViaIP(t, tr, routerClient.MustIPv4())
}