policy: reduce routes sent to peers based on packetfilter (#2561)
* notifier: use convenience funcs Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * policy: reduce routes based on policy Fixes #2365 Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * hsic: more helper methods Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * policy: more test cases Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * integration: add route with filter acl integration test Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * integration: correct route reduce test, now failing Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * mapper: compare peer routes against node Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * hs: more output to debug strings Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * types/node: slice.ContainsFunc Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * policy: more reduce route test Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * changelog: add entry for route filter Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b9868f6516
commit
45e38cb080
16 changed files with 903 additions and 47 deletions
|
@ -21,6 +21,8 @@ type ControlServer interface {
|
|||
CreateUser(user string) (*v1.User, error)
|
||||
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
|
||||
ListNodes(users ...string) ([]*v1.Node, error)
|
||||
NodesByUser() (map[string][]*v1.Node, error)
|
||||
NodesByName() (map[string]*v1.Node, error)
|
||||
ListUsers() ([]*v1.User, error)
|
||||
MapUsers() (map[string]*v1.User, error)
|
||||
ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)
|
||||
|
|
|
@ -819,6 +819,38 @@ func (t *HeadscaleInContainer) ListNodes(
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
|
||||
nodes, err := t.ListNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userMap map[string][]*v1.Node
|
||||
for _, node := range nodes {
|
||||
if _, ok := userMap[node.User.Name]; !ok {
|
||||
mak.Set(&userMap, node.User.Name, []*v1.Node{node})
|
||||
} else {
|
||||
userMap[node.User.Name] = append(userMap[node.User.Name], node)
|
||||
}
|
||||
}
|
||||
|
||||
return userMap, nil
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) NodesByName() (map[string]*v1.Node, error) {
|
||||
nodes, err := t.ListNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var nameMap map[string]*v1.Node
|
||||
for _, node := range nodes {
|
||||
mak.Set(&nameMap, node.GetName(), node)
|
||||
}
|
||||
|
||||
return nameMap, nil
|
||||
}
|
||||
|
||||
// ListUsers returns a list of users from Headscale.
|
||||
func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) {
|
||||
command := []string{"headscale", "users", "list", "--output", "json"}
|
||||
|
@ -973,7 +1005,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
|
|||
"headscale", "nodes", "approve-routes",
|
||||
"--output", "json",
|
||||
"--identifier", strconv.FormatUint(id, 10),
|
||||
fmt.Sprintf("--routes=%q", strings.Join(util.PrefixesToString(routes), ",")),
|
||||
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")),
|
||||
}
|
||||
|
||||
result, _, err := dockertestutil.ExecuteCommand(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
@ -9,7 +10,7 @@ import (
|
|||
|
||||
"slices"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
cmpdiff "github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
|
@ -23,6 +24,7 @@ import (
|
|||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/slicesx"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
@ -940,7 +942,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
|
||||
}
|
||||
|
||||
|
@ -990,7 +992,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
|
||||
}
|
||||
}
|
||||
|
@ -1603,9 +1605,9 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, dbMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
|
||||
for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
|
||||
for _, advertiseDuringUp := range []bool{false, true} {
|
||||
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, dbMode)
|
||||
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
scenario, err := NewScenario(tt.spec)
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
|
@ -1616,7 +1618,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithACLPolicy(tt.pol),
|
||||
hsic.WithPolicyMode(dbMode),
|
||||
hsic.WithPolicyMode(polMode),
|
||||
}
|
||||
|
||||
tsOpts := []tsic.Option{
|
||||
|
@ -2007,7 +2009,7 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
|
|||
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
|
||||
if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Fatalf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff)
|
||||
}
|
||||
}
|
||||
|
@ -2018,3 +2020,193 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
|
|||
require.Lenf(t, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
|
||||
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
|
||||
}
|
||||
|
||||
// TestSubnetRouteACLFiltering tests that a node can only access subnet routes
|
||||
// that are explicitly allowed in the ACL.
|
||||
func TestSubnetRouteACLFiltering(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
// Use router and node users for better clarity
|
||||
routerUser := "router"
|
||||
nodeUser := "node"
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: 1,
|
||||
Users: []string{routerUser, nodeUser},
|
||||
Networks: map[string][]string{
|
||||
"usernet1": {routerUser, nodeUser},
|
||||
},
|
||||
ExtraService: map[string][]extraServiceFunc{
|
||||
"usernet1": {Webservice},
|
||||
},
|
||||
// We build the head image with curl and traceroute, so only use
|
||||
// that for this test.
|
||||
Versions: []string{"head"},
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
|
||||
aclPolicyStr := fmt.Sprintf(`{
|
||||
"hosts": {
|
||||
"router": "100.64.0.1/32",
|
||||
"node": "100.64.0.2/32"
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"*"
|
||||
],
|
||||
"dst": [
|
||||
"router:8000"
|
||||
]
|
||||
},
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"node"
|
||||
],
|
||||
"dst": []
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
route, err := scenario.SubnetOfNetwork("usernet1")
|
||||
require.NoError(t, err)
|
||||
|
||||
services, err := scenario.Services("usernet1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
usernet1, err := scenario.Network("usernet1")
|
||||
require.NoError(t, err)
|
||||
|
||||
web := services[0]
|
||||
webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1))
|
||||
weburl := fmt.Sprintf("http://%s/etc/hostname", webip)
|
||||
t.Logf("webservice: %s, %s", webip.String(), weburl)
|
||||
|
||||
// Create ACL policy
|
||||
aclPolicy := &policyv1.ACLPolicy{}
|
||||
err = json.Unmarshal([]byte(aclPolicyStr), aclPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{
|
||||
tsic.WithAcceptRoutes(),
|
||||
}, hsic.WithTestName("routeaclfilter"),
|
||||
hsic.WithACLPolicy(aclPolicy),
|
||||
hsic.WithPolicyMode(types.PolicyModeDB),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
|
||||
// Sort clients by ID for consistent order
|
||||
slices.SortFunc(allClients, func(a, b TailscaleClient) int {
|
||||
return b.MustIPv4().Compare(a.MustIPv4())
|
||||
})
|
||||
|
||||
// Get the router and node clients
|
||||
routerClient := allClients[0]
|
||||
nodeClient := allClients[1]
|
||||
|
||||
aclPolicy.Hosts = policyv1.Hosts{
|
||||
routerUser: must.Get(routerClient.MustIPv4().Prefix(32)),
|
||||
nodeUser: must.Get(nodeClient.MustIPv4().Prefix(32)),
|
||||
}
|
||||
aclPolicy.ACLs[1].Destinations = []string{
|
||||
route.String() + ":*",
|
||||
}
|
||||
|
||||
require.NoError(t, headscale.SetPolicy(aclPolicy))
|
||||
|
||||
// Set up the subnet routes for the router
|
||||
routes := []string{
|
||||
route.String(), // This should be accessible by the client
|
||||
"10.10.11.0/24", // These should NOT be accessible
|
||||
"10.10.12.0/24",
|
||||
}
|
||||
|
||||
routeArg := "--advertise-routes=" + routes[0] + "," + routes[1] + "," + routes[2]
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"set",
|
||||
routeArg,
|
||||
}
|
||||
|
||||
_, _, err = routerClient.Execute(command)
|
||||
require.NoErrorf(t, err, "failed to advertise routes: %s", err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
// List nodes and verify the router has 3 available routes
|
||||
nodes, err := headscale.NodesByUser()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
|
||||
// Find the router node
|
||||
routerNode := nodes[routerUser][0]
|
||||
nodeNode := nodes[nodeUser][0]
|
||||
|
||||
require.NotNil(t, routerNode, "Router node not found")
|
||||
require.NotNil(t, nodeNode, "Client node not found")
|
||||
|
||||
// Check that the router has 3 routes available but not approved yet
|
||||
requireNodeRouteCount(t, routerNode, 3, 0, 0)
|
||||
requireNodeRouteCount(t, nodeNode, 0, 0, 0)
|
||||
|
||||
// Approve all routes for the router
|
||||
_, err = headscale.ApproveRoutes(
|
||||
routerNode.GetId(),
|
||||
util.MustStringsToPrefixes(routerNode.GetAvailableRoutes()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give some time for the routes to propagate
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// List nodes and verify the router has 3 available routes
|
||||
nodes, err = headscale.NodesByUser()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
|
||||
// Find the router node
|
||||
routerNode = nodes[routerUser][0]
|
||||
|
||||
// Check that the router has 3 routes now approved and available
|
||||
requireNodeRouteCount(t, routerNode, 3, 3, 3)
|
||||
|
||||
// Now check the client node status
|
||||
nodeStatus, err := nodeClient.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
routerStatus, err := routerClient.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the node can see the subnet routes from the router
|
||||
routerPeerStatus := nodeStatus.Peer[routerStatus.Self.PublicKey]
|
||||
|
||||
// The node should only have 1 subnet route
|
||||
requirePeerSubnetRoutes(t, routerPeerStatus, []netip.Prefix{*route})
|
||||
|
||||
result, err := nodeClient.Curl(weburl)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 13)
|
||||
|
||||
tr, err := nodeClient.Traceroute(webip)
|
||||
require.NoError(t, err)
|
||||
assertTracerouteViaIP(t, tr, routerClient.MustIPv4())
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue