Make matchers part of the Policy interface (#2514)
* Make matchers part of the Policy interface * Prevent race condition between rules and matchers * Test also matchers in tests for Policy.Filter * Compute `filterChanged` in v2 policy correctly * Fix nil vs. empty list issue in v2 policy test * policy/v2: always clear ssh map Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> Co-authored-by: Aras Ergus <aras.ergus@tngtech.com> Co-authored-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
eb1ecefd9e
commit
4651d06fa8
12 changed files with 89 additions and 43 deletions
|
@ -13,6 +13,14 @@ type Match struct {
|
|||
dests *netipx.IPSet
|
||||
}
|
||||
|
||||
func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
|
||||
matches := make([]Match, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
matches = append(matches, MatchFromFilterRule(rule))
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||
dests := []string{}
|
||||
for _, dest := range rule.DstPorts {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"net/netip"
|
||||
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
|
@ -15,7 +16,8 @@ var (
|
|||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
Filter() []tailcfg.FilterRule
|
||||
// Filter returns the current filter rules for the entire tailnet and the associated matchers.
|
||||
Filter() ([]tailcfg.FilterRule, []matcher.Match)
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
|
@ -15,7 +16,7 @@ import (
|
|||
func FilterNodesByACL(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
filter []tailcfg.FilterRule,
|
||||
matchers []matcher.Match,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
|
||||
|
@ -24,7 +25,7 @@ func FilterNodesByACL(
|
|||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||
if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package policy
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
|
@ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
var err error
|
||||
pm, err = pmf(users, append(tt.peers, tt.node))
|
||||
require.NoError(t, err)
|
||||
got := pm.Filter()
|
||||
got, _ := pm.Filter()
|
||||
got = ReduceFilterRules(tt.node, got)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
|
@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
got := FilterNodesByACL(
|
||||
tt.args.node,
|
||||
tt.args.nodes,
|
||||
tt.args.rules,
|
||||
matchers,
|
||||
)
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
||||
|
|
|
@ -2,6 +2,7 @@ package v1
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -88,10 +89,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
return pm.filter, matcher.MatchesFromFilterRules(pm.filter)
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) {
|
|||
wantNodesChange bool
|
||||
wantPolicyChange bool
|
||||
wantFilter []tailcfg.FilterRule
|
||||
wantMatchers []matcher.Match
|
||||
}{
|
||||
{
|
||||
name: "set-nodes",
|
||||
|
@ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
wantMatchers: []matcher.Match{
|
||||
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users",
|
||||
|
@ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
wantMatchers: []matcher.Match{
|
||||
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users-and-node",
|
||||
|
@ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
wantMatchers: []matcher.Match{
|
||||
matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-policy",
|
||||
|
@ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
wantMatchers: []matcher.Match{
|
||||
matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) {
|
|||
assert.Equal(t, tt.wantNodesChange, change)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
||||
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
||||
filter, matchers := pm.Filter()
|
||||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
cmp.AllowUnexported(matcher.Match{}),
|
||||
); diff != "" {
|
||||
t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
|
@ -24,6 +26,7 @@ type PolicyManager struct {
|
|||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
matchers []matcher.Match
|
||||
|
||||
tagOwnerMapHash deephash.Sum
|
||||
tagOwnerMap map[Tag]*netipx.IPSet
|
||||
|
@ -62,15 +65,24 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM
|
|||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
// Clear the SSH policy map to ensure it's recalculated with the new policy.
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
defer clear(pm.sshPolicyMap)
|
||||
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
filterChanged := filterHash == pm.filterHash
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||
}
|
||||
|
||||
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
||||
// that the map for tag owners is resolved before resolving autoapprovers.
|
||||
|
@ -100,12 +112,6 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
// Clear the SSH policy map to ensure it's recalculated with the new policy.
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
@ -144,11 +150,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
|||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// Filter returns the current filter rules for the entire tailnet.
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
// Filter returns the current filter rules for the entire tailnet and the associated matchers.
|
||||
func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
return pm.filter, pm.matchers
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) {
|
|||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol string
|
||||
nodes types.Nodes
|
||||
wantFilter []tailcfg.FilterRule
|
||||
name string
|
||||
pol string
|
||||
nodes types.Nodes
|
||||
wantFilter []tailcfg.FilterRule
|
||||
wantMatchers []matcher.Match
|
||||
}{
|
||||
{
|
||||
name: "empty-policy",
|
||||
pol: "{}",
|
||||
nodes: types.Nodes{},
|
||||
wantFilter: nil,
|
||||
name: "empty-policy",
|
||||
pol: "{}",
|
||||
nodes: types.Nodes{},
|
||||
wantFilter: nil,
|
||||
wantMatchers: []matcher.Match{},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) {
|
|||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
filter := pm.Filter()
|
||||
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||
filter, matchers := pm.Filter()
|
||||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
cmp.AllowUnexported(matcher.Match{}),
|
||||
); diff != "" {
|
||||
t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Test SSH Policy
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue