wrap policy in policy manager interface (#2255)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
2c1ad6d11a
commit
f7b0cbbbea
16 changed files with 742 additions and 371 deletions
181
hscontrol/policy/pm.go
Normal file
181
hscontrol/policy/pm.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
Filter() []tailcfg.FilterRule
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
Tags(*types.Node) []string
|
||||
ApproversForRoute(netip.Prefix) []string
|
||||
ExpandAlias(string) (*netipx.IPSet, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
}
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err := pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManagerV1 struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
|
||||
if pm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags, _ := pm.pol.TagsOfNode(node)
|
||||
return tags
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
|
||||
// TODO(kradalby): This can be a parse error of the address in the policy,
|
||||
// in the new policy this will be typed and not a problem, in this policy
|
||||
// we will just return empty list
|
||||
if pm.pol == nil {
|
||||
return nil
|
||||
}
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
return approvers
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips, nil
|
||||
}
|
158
hscontrol/policy/pm_test.go
Normal file
158
hscontrol/policy/pm_test.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestPolicySetChange(t *testing.T) {
|
||||
users := []types.User{
|
||||
{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "testuser",
|
||||
},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
policy []byte
|
||||
wantUsersChange bool
|
||||
wantNodesChange bool
|
||||
wantPolicyChange bool
|
||||
wantFilter []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "set-nodes",
|
||||
nodes: types.Nodes{
|
||||
{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
wantNodesChange: false,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users",
|
||||
users: users,
|
||||
wantUsersChange: false,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users-and-node",
|
||||
users: users,
|
||||
nodes: types.Nodes{
|
||||
{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
wantUsersChange: false,
|
||||
wantNodesChange: true,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-policy",
|
||||
policy: []byte(`
|
||||
{
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"100.64.0.61",
|
||||
],
|
||||
"dst": [
|
||||
"100.64.0.62:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`),
|
||||
wantPolicyChange: true,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.61/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol := `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.64.0.1",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`
|
||||
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.policy != nil {
|
||||
change, err := pm.SetPolicy(tt.policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantPolicyChange, change)
|
||||
}
|
||||
|
||||
if tt.users != nil {
|
||||
change, err := pm.SetUsers(tt.users)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantUsersChange, change)
|
||||
}
|
||||
|
||||
if tt.nodes != nil {
|
||||
change, err := pm.SetNodes(tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantNodesChange, change)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
||||
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue