Experimental implementation of Policy v2 (#2214)
* utility iterator for ipset Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * split policy -> policy and v1 This commit split out the common policy logic and policy implementation into separate packages. policy contains functions that are independent of the policy implementation, this typically means logic that works on tailcfg types and generic formats. In addition, it defines the PolicyManager interface which the v1 implements. v1 is a subpackage which implements the PolicyManager using the "original" policy implementation. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use polivyv1 definitions in integration tests These can be marshalled back into JSON, which the new format might not be able to. Also, just dont change it all to JSON strings for now. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * formatter: breaks lines Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * remove compareprefix, use tsaddr version Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * remove getacl test, add back autoapprover Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use policy manager tag handling Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * rename display helper for user Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * introduce policy v2 package policy v2 is built from the ground up to be stricter and follow the same pattern for all types of resolvers. TODO introduce aliass resolver Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * wire up policyv2 in integration testing Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * split policy v2 tests into seperate workflow to work around github limit Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * add policy manager output to /debug Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * update changelog Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b6fbd37539
commit
87326f5c4f
41 changed files with 5883 additions and 2118 deletions
|
@ -194,10 +194,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
|
||||
var magicDNSDomains []dnsname.FQDN
|
||||
if cfg.PrefixV4 != nil {
|
||||
magicDNSDomains = append(magicDNSDomains, util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||
}
|
||||
if cfg.PrefixV6 != nil {
|
||||
magicDNSDomains = append(magicDNSDomains, util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
|
||||
}
|
||||
|
||||
// we might have routes already from Split DNS
|
||||
|
@ -459,11 +463,13 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
router := mux.NewRouter()
|
||||
router.Use(prometheusMiddleware)
|
||||
|
||||
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost, http.MethodGet)
|
||||
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).
|
||||
Methods(http.MethodPost, http.MethodGet)
|
||||
|
||||
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
|
||||
Methods(http.MethodGet)
|
||||
|
||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
||||
|
@ -523,7 +529,11 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
|
|||
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// Maybe this should be implemented as an event bus?
|
||||
// A bool is returned indicating if a full update was sent to all nodes
|
||||
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) {
|
||||
func nodesChangedHook(
|
||||
db *db.HSDatabase,
|
||||
polMan policy.PolicyManager,
|
||||
notif *notifier.Notifier,
|
||||
) (bool, error) {
|
||||
nodes, err := db.ListNodes()
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -1143,6 +1153,7 @@ func (h *Headscale) loadPolicyManager() error {
|
|||
errOut = fmt.Errorf("creating policy manager: %w", err)
|
||||
return
|
||||
}
|
||||
log.Info().Msgf("Using policy manager version: %d", h.polMan.Version())
|
||||
|
||||
if len(nodes) > 0 {
|
||||
_, err = h.polMan.SSHPolicy(nodes[0])
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/util/set"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
@ -655,7 +656,7 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
for nodeID, routes := range nodeRoutes {
|
||||
slices.SortFunc(routes, util.ComparePrefix)
|
||||
tsaddr.SortPrefixes(routes)
|
||||
slices.Compact(routes)
|
||||
|
||||
data, err := json.Marshal(routes)
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
|
@ -146,105 +147,6 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
type base struct {
|
||||
user *types.User
|
||||
key *types.PreAuthKey
|
||||
}
|
||||
|
||||
stor := make([]base, 0)
|
||||
|
||||
for _, name := range []string{"test", "admin"} {
|
||||
user, err := db.CreateUser(types.User{Name: name})
|
||||
c.Assert(err, check.IsNil)
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
stor = append(stor, base{user, pak})
|
||||
}
|
||||
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPv4: &v4,
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: stor[index%2].user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(stor[index%2].key.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
}
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:test": {"admin"},
|
||||
},
|
||||
Hosts: map[string]netip.Prefix{},
|
||||
TagOwners: map[string][]string{},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"admin"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test"},
|
||||
Destinations: []string{"test:*"},
|
||||
},
|
||||
},
|
||||
Tests: []policy.ACLTest{},
|
||||
}
|
||||
|
||||
adminNode, err := db.GetNodeByID(1)
|
||||
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
|
||||
c.Assert(adminNode.IPv4, check.NotNil)
|
||||
c.Assert(adminNode.IPv6, check.IsNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testNode, err := db.GetNodeByID(2)
|
||||
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
adminPeers, err := db.ListPeers(adminNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(adminPeers), check.Equals, 9)
|
||||
|
||||
testPeers, err := db.ListPeers(testNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(testPeers), check.Equals, 9)
|
||||
|
||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
|
||||
c.Log(peersOfAdminNode)
|
||||
c.Log(peersOfTestNode)
|
||||
|
||||
c.Assert(len(peersOfTestNode), check.Equals, 9)
|
||||
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
|
||||
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
|
||||
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
|
||||
|
||||
c.Assert(len(peersOfAdminNode), check.Equals, 9)
|
||||
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
|
||||
c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7")
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -456,143 +358,171 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): replace this test
|
||||
// func TestAutoApproveRoutes(t *testing.T) {
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// acl string
|
||||
// routes []netip.Prefix
|
||||
// want []netip.Prefix
|
||||
// }{
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "groups": {
|
||||
// "group:k8s": ["test"]
|
||||
// },
|
||||
func TestAutoApproveRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "2068-approve-issue-sub-kube",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:k8s": ["test@"]
|
||||
},
|
||||
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
// "autoApprovers": {
|
||||
// "routes": {
|
||||
// "10.42.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// },
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "tagOwners": {
|
||||
// "tag:exit": ["test"],
|
||||
// },
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.42.0.0/16": ["test@"],
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-exit-tag",
|
||||
acl: `
|
||||
{
|
||||
"tagOwners": {
|
||||
"tag:exit": ["test@"],
|
||||
},
|
||||
|
||||
// "groups": {
|
||||
// "group:test": ["test"]
|
||||
// },
|
||||
"groups": {
|
||||
"group:test": ["test@"]
|
||||
},
|
||||
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
// "autoApprovers": {
|
||||
// "exitNode": ["tag:exit"],
|
||||
// "routes": {
|
||||
// "10.10.0.0/16": ["group:test"],
|
||||
// "10.11.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// tsaddr.AllIPv6(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// },
|
||||
// want: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// tsaddr.AllIPv6(),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
"autoApprovers": {
|
||||
"exitNode": ["tag:exit"],
|
||||
"routes": {
|
||||
"10.10.0.0/16": ["group:test"],
|
||||
"10.11.0.0/16": ["test@"],
|
||||
"8.11.0.0/24": ["test2@"], // No nodes
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// adb, err := newSQLiteTestDB()
|
||||
// require.NoError(t, err)
|
||||
// pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||
// Not approved
|
||||
netip.MustParsePrefix("8.11.0.0/24"),
|
||||
},
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
},
|
||||
want2: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// require.NoError(t, err)
|
||||
// require.NotNil(t, pol)
|
||||
for _, tt := range tests {
|
||||
pmfs := policy.PolicyManagerFuncsForTest([]byte(tt.acl))
|
||||
for i, pmf := range pmfs {
|
||||
version := i + 1
|
||||
t.Run(fmt.Sprintf("%s-policyv%d", tt.name, version), func(t *testing.T) {
|
||||
adb, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
// user, err := adb.CreateUser(types.User{Name: "test"})
|
||||
// require.NoError(t, err)
|
||||
suffix := ""
|
||||
if version == 1 {
|
||||
suffix = "@"
|
||||
}
|
||||
|
||||
// pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, nil, nil)
|
||||
// require.NoError(t, err)
|
||||
user, err := adb.CreateUser(types.User{Name: "test" + suffix})
|
||||
require.NoError(t, err)
|
||||
_, err = adb.CreateUser(types.User{Name: "test2" + suffix})
|
||||
require.NoError(t, err)
|
||||
taggedUser, err := adb.CreateUser(types.User{Name: "tagged" + suffix})
|
||||
require.NoError(t, err)
|
||||
|
||||
// nodeKey := key.NewNode()
|
||||
// machineKey := key.NewMachine()
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
}
|
||||
|
||||
// v4 := netip.MustParseAddr("100.64.0.1")
|
||||
// node := types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKey.Public(),
|
||||
// NodeKey: nodeKey.Public(),
|
||||
// Hostname: "test",
|
||||
// UserID: user.ID,
|
||||
// RegisterMethod: util.RegisterMethodAuthKey,
|
||||
// AuthKeyID: ptr.To(pak.ID),
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RequestTags: []string{"tag:exit"},
|
||||
// RoutableIPs: tt.routes,
|
||||
// },
|
||||
// IPv4: &v4,
|
||||
// }
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// trx := adb.DB.Save(&node)
|
||||
// require.NoError(t, trx.Error)
|
||||
nodeTagged := types.Node{
|
||||
ID: 2,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "taggednode",
|
||||
UserID: taggedUser.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
ForcedTags: []string{"tag:exit"},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// sendUpdate, err := adb.SaveNodeRoutes(&node)
|
||||
// require.NoError(t, err)
|
||||
// assert.False(t, sendUpdate)
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// node0ByID, err := adb.GetNodeByID(0)
|
||||
// require.NoError(t, err)
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// users, err := adb.ListUsers()
|
||||
// assert.NoError(t, err)
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nodes, err := adb.ListNodes()
|
||||
// assert.NoError(t, err)
|
||||
pm, err := pmf(users, nodes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
// pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
|
||||
// assert.NoError(t, err)
|
||||
changed1 := policy.AutoApproveRoutes(pm, &node)
|
||||
assert.True(t, changed1)
|
||||
|
||||
// // TODO(kradalby): Check state update
|
||||
// err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
|
||||
// require.NoError(t, err)
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||
// require.NoError(t, err)
|
||||
// assert.Len(t, enabledRoutes, len(tt.want))
|
||||
_ = policy.AutoApproveRoutes(pm, &nodeTagged)
|
||||
|
||||
// tsaddr.SortPrefixes(enabledRoutes)
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
|
||||
// t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
node1ByID, err := adb.GetNodeByID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
node2ByID, err := adb.GetNodeByID(2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
|
|
|
@ -105,6 +105,11 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.primaryRoutes.String()))
|
||||
}))
|
||||
debug.Handle("policy-manager", "Policy Manager", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.polMan.DebugString()))
|
||||
}))
|
||||
|
||||
err := statsviz.Register(debugMux)
|
||||
if err == nil {
|
||||
|
|
|
@ -348,7 +348,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||
routes = append(routes, prefix)
|
||||
}
|
||||
}
|
||||
slices.SortFunc(routes, util.ComparePrefix)
|
||||
tsaddr.SortPrefixes(routes)
|
||||
slices.Compact(routes)
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
|
@ -525,7 +525,12 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty
|
|||
resp.Online = true
|
||||
}
|
||||
|
||||
tags := polMan.Tags(node)
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
if polMan.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
response[index] = resp
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -246,7 +247,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
node *types.Node
|
||||
peers types.Nodes
|
||||
|
||||
|
@ -258,7 +259,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
// {
|
||||
// name: "empty-node",
|
||||
// node: types.Node{},
|
||||
// pol: &policy.ACLPolicy{},
|
||||
// pol: &policyv1.ACLPolicy{},
|
||||
// dnsConfig: &tailcfg.DNSConfig{},
|
||||
// baseDomain: "",
|
||||
// want: nil,
|
||||
|
@ -266,7 +267,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
// },
|
||||
{
|
||||
name: "no-pol-no-peers-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{},
|
||||
derpMap: &tailcfg.DERPMap{},
|
||||
|
@ -284,10 +284,15 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{
|
||||
ID: tailcfg.UserID(user1.ID),
|
||||
LoginName: "user1",
|
||||
DisplayName: "user1",
|
||||
},
|
||||
},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
|
@ -296,7 +301,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "no-pol-with-peer-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
|
@ -318,13 +322,12 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
|
@ -333,18 +336,17 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "with-pol-map-response",
|
||||
pol: &policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"mini": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"100.64.0.2"},
|
||||
Destinations: []string{"mini:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
pol: []byte(`
|
||||
{
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["100.64.0.2"],
|
||||
"dst": ["user1:*"],
|
||||
},
|
||||
],
|
||||
}
|
||||
`),
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
|
@ -374,11 +376,11 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
|
@ -390,7 +392,8 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
|
||||
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
|
||||
|
|
|
@ -81,7 +81,12 @@ func tailNode(
|
|||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||
}
|
||||
|
||||
tags := polMan.Tags(node)
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
if polMan.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
|
||||
tNode := tailcfg.Node{
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
@ -49,7 +50,7 @@ func TestTailNode(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
dnsConfig *tailcfg.DNSConfig
|
||||
baseDomain string
|
||||
want *tailcfg.Node
|
||||
|
@ -61,7 +62,6 @@ func TestTailNode(t *testing.T) {
|
|||
GivenName: "empty",
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
|
@ -117,7 +117,6 @@ func TestTailNode(t *testing.T) {
|
|||
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
|
||||
CreatedAt: created,
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
|
@ -179,7 +178,8 @@ func TestTailNode(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
cfg := &types.Config{
|
||||
BaseDomain: tt.baseDomain,
|
||||
|
@ -248,7 +248,7 @@ func TestNodeExpiry(t *testing.T) {
|
|||
tn, err := tailNode(
|
||||
node,
|
||||
0,
|
||||
&policy.PolicyManagerV1{},
|
||||
nil, // TODO(kradalby): removed in merge but error?
|
||||
nil,
|
||||
&types.Config{},
|
||||
)
|
||||
|
|
|
@ -513,7 +513,7 @@ func renderOIDCCallbackTemplate(
|
|||
) (*bytes.Buffer, error) {
|
||||
var content bytes.Buffer
|
||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||
User: user.DisplayNameOrUsername(),
|
||||
User: user.Display(),
|
||||
Verb: verb,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
|
||||
|
|
|
@ -1,219 +1,81 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
var (
|
||||
polv2 = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
Filter() []tailcfg.FilterRule
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
Tags(*types.Node) []string
|
||||
ApproversForRoute(netip.Prefix) []string
|
||||
ExpandAlias(string) (*netipx.IPSet, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(*types.Node, string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
}
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
// NewPolicyManager returns a new policy manager, the version is determined by
|
||||
// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2".
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if polv2 {
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
polMan, err = policyv1.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
return polMan, err
|
||||
}
|
||||
|
||||
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
// PolicyManagersForTest returns all available PostureManagers to be used
|
||||
// in tests to validate them in tests that try to determine that they
|
||||
// behave the same.
|
||||
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) {
|
||||
var polMans []PolicyManager
|
||||
|
||||
_, err := pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManagerV1 struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
|
||||
if pm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
|
||||
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
|
||||
return tags
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
|
||||
// TODO(kradalby): This can be a parse error of the address in the policy,
|
||||
// in the new policy this will be typed and not a problem, in this policy
|
||||
// we will just return empty list
|
||||
if pm.pol == nil {
|
||||
return nil
|
||||
}
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
return approvers
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
for _, pmf := range PolicyManagerFuncsForTest(pol) {
|
||||
pm, err := pmf(users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
return false
|
||||
return polMans, nil
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv1.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
|
||||
return polmanFuncs
|
||||
}
|
||||
|
|
109
hscontrol/policy/policy.go
Normal file
109
hscontrol/policy/policy.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
|
||||
func FilterNodesByACL(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
filter []tailcfg.FilterRule,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
|
||||
for index, peer := range nodes {
|
||||
if peer.ID == node.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
// Fail closed, if we can't parse it, then we should not allow
|
||||
// access.
|
||||
if err != nil {
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if node.Hostinfo != nil {
|
||||
if len(node.Hostinfo.RoutableIPs) > 0 {
|
||||
for _, routableIP := range node.Hostinfo.RoutableIPs {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
ret = append(ret, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: dests,
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// AutoApproveRoutes approves any route that can be autoapproved from
|
||||
// the nodes perspective according to the given policy.
|
||||
// It reports true if any routes were approved.
|
||||
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range node.AnnouncedRoutes() {
|
||||
if pm.NodeCanApproveRoute(node, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
node.ApprovedRoutes = newApproved
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
1455
hscontrol/policy/policy_test.go
Normal file
1455
hscontrol/policy/policy_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,11 +1,10 @@
|
|||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
@ -18,7 +17,6 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -37,38 +35,6 @@ const (
|
|||
expectedTokenItems = 2
|
||||
)
|
||||
|
||||
var theInternetSet *netipx.IPSet
|
||||
|
||||
// theInternet returns the IPSet for the Internet.
|
||||
// https://www.youtube.com/watch?v=iDbyYGrswtg
|
||||
func theInternet() *netipx.IPSet {
|
||||
if theInternetSet != nil {
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "can't find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
|
@ -240,53 +206,6 @@ func (pol *ACLPolicy) CompileFilterRules(
|
|||
return rules, nil
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
// TODO(kradalby): Make this nil and not alloc unless needed
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
// Fail closed, if we can't parse it, then we should not allow
|
||||
// access.
|
||||
if err != nil {
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if len(node.SubnetRoutes()) > 0 {
|
||||
for _, routableIP := range node.SubnetRoutes() {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
ret = append(ret, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: dests,
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
node *types.Node,
|
||||
users []types.User,
|
||||
|
@ -418,7 +337,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
|
||||
}
|
||||
for addr := range ipSetAll(ips) {
|
||||
for addr := range util.IPSetAddrIter(ips) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
|
@ -441,19 +360,6 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
|||
}, nil
|
||||
}
|
||||
|
||||
// ipSetAll returns a function that iterates over all the IPs in the IPSet.
|
||||
func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
|
||||
return func(yield func(netip.Addr) bool) {
|
||||
for _, rng := range ipSet.Ranges() {
|
||||
for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
|
||||
if !yield(ip) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||
sessionLength, err := time.ParseDuration(duration)
|
||||
if err != nil {
|
||||
|
@ -950,7 +856,7 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix(
|
|||
func expandAutoGroup(alias string) (*netipx.IPSet, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(alias, "autogroup:internet"):
|
||||
return theInternet(), nil
|
||||
return util.TheInternet(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown autogroup %q", alias)
|
||||
|
@ -1084,24 +990,3 @@ func findUserFromToken(users []types.User, token string) (types.User, error) {
|
|||
|
||||
return potentialUsers[0], nil
|
||||
}
|
||||
|
||||
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
|
||||
func FilterNodesByACL(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
filter []tailcfg.FilterRule,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
|
||||
for index, peer := range nodes {
|
||||
if peer.ID == node.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
|||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
187
hscontrol/policy/v1/policy.go
Normal file
187
hscontrol/policy/v1/policy.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManager{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
if pm == nil || pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
|
||||
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
|
||||
|
||||
for _, t := range tags {
|
||||
if t == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm == nil || pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Version() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) DebugString() string {
|
||||
return "not implemented for v1"
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
169
hscontrol/policy/v2/filter.go
Normal file
169
hscontrol/policy/v2/filter.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
)
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) compileFilterRules(
|
||||
users types.Users,
|
||||
nodes types.Nodes,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
if pol == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
|
||||
var rules []tailcfg.FilterRule
|
||||
|
||||
for _, acl := range pol.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, ErrInvalidAction
|
||||
}
|
||||
|
||||
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
}
|
||||
|
||||
if len(srcIPs.Prefixes()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(kradalby): integrate type into schema
|
||||
// TODO(kradalby): figure out the _ is wildcard stuff
|
||||
protocols, _, err := parseProtocol(acl.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
|
||||
}
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
for _, port := range dest.Ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: pref.String(),
|
||||
Ports: port,
|
||||
}
|
||||
destPorts = append(destPorts, pr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(destPorts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, tailcfg.FilterRule{
|
||||
SrcIPs: ipSetToPrefixStringList(srcIPs),
|
||||
DstPorts: destPorts,
|
||||
IPProto: protocols,
|
||||
})
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
return tailcfg.SSHAction{
|
||||
Reject: !accept,
|
||||
Accept: accept,
|
||||
SessionDuration: duration,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
var dest netipx.IPSetBuilder
|
||||
for _, src := range rule.Destinations {
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
||||
destSet, err := dest.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !node.InIPSet(destSet) {
|
||||
continue
|
||||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
switch rule.Action {
|
||||
case "accept":
|
||||
action = sshAction(true, 0)
|
||||
case "check":
|
||||
action = sshAction(true, rule.CheckPeriod)
|
||||
default:
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
}
|
||||
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
}
|
||||
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
userMap := make(map[string]string, len(rule.Users))
|
||||
for _, user := range rule.Users {
|
||||
userMap[user.String()] = "="
|
||||
}
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
|
||||
return &tailcfg.SSHPolicy{
|
||||
Rules: rules,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
var out []string
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
out = append(out, pref.String())
|
||||
}
|
||||
return out
|
||||
}
|
378
hscontrol/policy/v2/filter_test.go
Normal file
378
hscontrol/policy/v2/filter_test.go
Normal file
|
@ -0,0 +1,378 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestParsing(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
acl string
|
||||
want []tailcfg.FilterRule
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid-hujson",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
`,
|
||||
want: []tailcfg.FilterRule{},
|
||||
wantErr: true,
|
||||
},
|
||||
// The new parser will ignore all that is irrelevant
|
||||
// {
|
||||
// name: "valid-hujson-invalid-content",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {
|
||||
// "valid_json": true,
|
||||
// "but_a_policy_though": false
|
||||
// }
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
// {
|
||||
// name: "invalid-cidr",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {"example-host-1": "100.100.100.100/42"}
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
{
|
||||
name: "basic-rule",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
"192.168.1.0/24"
|
||||
],
|
||||
"dst": [
|
||||
"*:22,3389",
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "parse-protocol",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "tcp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "udp",
|
||||
"dst": [
|
||||
"host-1:53",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "icmp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{protocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolICMP, protocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-wildcard",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-range",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:5400-5500",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "100.100.100.100/32",
|
||||
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-group",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser@",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-user",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser@",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100/32",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol, err := policyFromBytes([]byte(tt.acl))
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := pol.compileFilterRules(
|
||||
users,
|
||||
types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.100.100.100"),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("200.200.200.200"),
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, rules); diff != "" {
|
||||
t.Errorf("parsing() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
283
hscontrol/policy/v2/policy.go
Normal file
283
hscontrol/policy/v2/policy.go
Normal file
|
@ -0,0 +1,283 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *Policy
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
|
||||
tagOwnerMapHash deephash.Sum
|
||||
tagOwnerMap map[Tag]*netipx.IPSet
|
||||
|
||||
autoApproveMapHash deephash.Sum
|
||||
autoApproveMap map[netip.Prefix]*netipx.IPSet
|
||||
|
||||
// Lazy map of SSH policies
|
||||
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
|
||||
}
|
||||
|
||||
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||
// It returns an error if the policy file is invalid.
|
||||
// The policy manager will update the filter rules based on the users and nodes.
|
||||
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
policy, err := policyFromBytes(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm := PolicyManager{
|
||||
pol: policy,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
filterChanged := filterHash == pm.filterHash
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
||||
// that the map for tag owners is resolved before resolving autoapprovers.
|
||||
// TODO(kradalby): Order might not matter after #2417
|
||||
tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("resolving tag owners map: %w", err)
|
||||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("resolving auto approvers map: %w", err)
|
||||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
// If neither of the calculated values changed, no need to update nodes
|
||||
if !filterChanged && !tagOwnerChanged && !autoApproveChanged {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear the SSH policy map to ensure it's recalculated with the new policy.
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
pm.sshPolicyMap[node.ID] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := policyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// Filter returns the current filter rules for the entire tailnet.
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if ips.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
// The fast path is that a node requests to approve a prefix
|
||||
// where there is an exact entry, e.g. 10.0.0.0/8, then
|
||||
// check and return quickly
|
||||
if _, ok := pm.autoApproveMap[route]; ok {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if pm.autoApproveMap[route].Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The slow path is that the node tries to approve
|
||||
// 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
|
||||
// cannot just lookup in the prefix map and have to check
|
||||
// if there is a "parent" prefix available.
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
// We do not want the exit node entry to approve all
|
||||
// sorts of routes. The logic here is that it would be
|
||||
// unexpected behaviour to have specific routes approved
|
||||
// just because the node is allowed to designate itself as
|
||||
// an exit.
|
||||
if tsaddr.IsExitRoute(prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if prefix is larger (so containing) and then overlaps
|
||||
// the route to see if the node can approve a subset of an autoapprover
|
||||
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if approveAddrs.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Version() int {
|
||||
return 2
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) DebugString() string {
|
||||
var sb strings.Builder
|
||||
|
||||
fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version())
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.pol != nil {
|
||||
pol, err := json.MarshalIndent(pm.pol, "", " ")
|
||||
if err == nil {
|
||||
sb.WriteString("Policy:\n")
|
||||
sb.Write(pol)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
sb.WriteString("Compiled filter:\n")
|
||||
sb.Write(filter)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
58
hscontrol/policy/v2/policy_test.go
Normal file
58
hscontrol/policy/v2/policy_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
IPv4: ap(ipv4),
|
||||
IPv6: ap(ipv6),
|
||||
User: user,
|
||||
UserID: user.ID,
|
||||
Hostinfo: hostinfo,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyManager(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol string
|
||||
nodes types.Nodes
|
||||
wantFilter []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "empty-policy",
|
||||
pol: "{}",
|
||||
nodes: types.Nodes{},
|
||||
wantFilter: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
filter := pm.Filter()
|
||||
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Test SSH Policy
|
||||
})
|
||||
}
|
||||
}
|
1005
hscontrol/policy/v2/types.go
Normal file
1005
hscontrol/policy/v2/types.go
Normal file
File diff suppressed because it is too large
Load diff
1162
hscontrol/policy/v2/types_test.go
Normal file
1162
hscontrol/policy/v2/types_test.go
Normal file
File diff suppressed because it is too large
Load diff
164
hscontrol/policy/v2/utils.go
Normal file
164
hscontrol/policy/v2/utils.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||
func splitDestinationAndPort(input string) (string, string, error) {
|
||||
// Find the last occurrence of the colon character
|
||||
lastColonIndex := strings.LastIndex(input, ":")
|
||||
|
||||
// Check if the colon character is present and not at the beginning or end of the string
|
||||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
}
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
}
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
}
|
||||
|
||||
// Split the string into destination and port based on the last colon
|
||||
destination := input[:lastColonIndex]
|
||||
port := input[lastColonIndex+1:]
|
||||
|
||||
return destination, port, nil
|
||||
}
|
||||
|
||||
// parsePortRange parses a port definition string and returns a slice of PortRange structs.
|
||||
func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
if portDef == "*" {
|
||||
return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil
|
||||
}
|
||||
|
||||
var portRanges []tailcfg.PortRange
|
||||
parts := strings.Split(portDef, ",")
|
||||
|
||||
for _, part := range parts {
|
||||
if strings.Contains(part, "-") {
|
||||
rangeParts := strings.Split(part, "-")
|
||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||
return e == ""
|
||||
})
|
||||
if len(rangeParts) != 2 {
|
||||
return nil, errors.New("invalid port range format")
|
||||
}
|
||||
|
||||
first, err := parsePort(rangeParts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
last, err := parsePort(rangeParts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if first > last {
|
||||
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||
} else {
|
||||
port, err := parsePort(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||
}
|
||||
}
|
||||
|
||||
return portRanges, nil
|
||||
}
|
||||
|
||||
// parsePort parses a single port number from a string.
|
||||
func parsePort(portStr string) (uint16, error) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid port number")
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, errors.New("port number out of range")
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
protocolIGMP = 2 // Internet Group Management
|
||||
protocolIPv4 = 4 // IPv4 encapsulation
|
||||
protocolTCP = 6 // Transmission Control
|
||||
protocolEGP = 8 // Exterior Gateway Protocol
|
||||
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||
protocolUDP = 17 // User Datagram
|
||||
protocolGRE = 47 // Generic Routing Encapsulation
|
||||
protocolESP = 50 // Encap Security Payload
|
||||
protocolAH = 51 // Authentication Header
|
||||
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||
ProtocolFC = 133 // Fibre Channel
|
||||
)
|
||||
|
||||
// parseProtocol reads the proto field of the ACL and generates a list of
|
||||
// protocols that will be allowed, following the IANA IP protocol number
|
||||
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||
//
|
||||
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
|
||||
// as per Tailscale behaviour (see tailcfg.FilterRule).
|
||||
//
|
||||
// Also returns a boolean indicating if the protocol
|
||||
// requires all the destinations to use wildcard as port number (only TCP,
|
||||
// UDP and SCTP support specifying ports).
|
||||
func parseProtocol(protocol string) ([]int, bool, error) {
|
||||
switch protocol {
|
||||
case "":
|
||||
return nil, false, nil
|
||||
case "igmp":
|
||||
return []int{protocolIGMP}, true, nil
|
||||
case "ipv4", "ip-in-ip":
|
||||
return []int{protocolIPv4}, true, nil
|
||||
case "tcp":
|
||||
return []int{protocolTCP}, false, nil
|
||||
case "egp":
|
||||
return []int{protocolEGP}, true, nil
|
||||
case "igp":
|
||||
return []int{protocolIGP}, true, nil
|
||||
case "udp":
|
||||
return []int{protocolUDP}, false, nil
|
||||
case "gre":
|
||||
return []int{protocolGRE}, true, nil
|
||||
case "esp":
|
||||
return []int{protocolESP}, true, nil
|
||||
case "ah":
|
||||
return []int{protocolAH}, true, nil
|
||||
case "sctp":
|
||||
return []int{protocolSCTP}, false, nil
|
||||
case "icmp":
|
||||
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
|
||||
|
||||
default:
|
||||
protocolNumber, err := strconv.Atoi(protocol)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): What is this?
|
||||
needsWildcard := protocolNumber != protocolTCP &&
|
||||
protocolNumber != protocolUDP &&
|
||||
protocolNumber != protocolSCTP
|
||||
|
||||
return []int{protocolNumber}, needsWildcard, nil
|
||||
}
|
||||
}
|
102
hscontrol/policy/v2/utils_test.go
Normal file
102
hscontrol/policy/v2/utils_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests.
|
||||
func TestParseDestinationAndPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedDst string
|
||||
expectedPort string
|
||||
expectedErr error
|
||||
}{
|
||||
{"git-server:*", "git-server", "*", nil},
|
||||
{"192.168.1.0/24:22", "192.168.1.0/24", "22", nil},
|
||||
{"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil},
|
||||
{"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil},
|
||||
{"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil},
|
||||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected uint16
|
||||
err string
|
||||
}{
|
||||
{"80", 80, ""},
|
||||
{"0", 0, ""},
|
||||
{"65535", 65535, ""},
|
||||
{"-1", 0, "port number out of range"},
|
||||
{"65536", 0, "port number out of range"},
|
||||
{"abc", 0, "invalid port number"},
|
||||
{"", 0, "invalid port number"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePort(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
{"80-90,", nil, "invalid port number"},
|
||||
{"80,90-", nil, "invalid port range format"},
|
||||
{"80-90,abc", nil, "invalid port number"},
|
||||
{"80-90,65536", nil, "port number out of range"},
|
||||
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePortRange(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,10 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
@ -459,25 +458,10 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||
// TODO(kradalby): I am not sure if we need this?
|
||||
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
||||
|
||||
// Take all the routes presented to us by the node and check
|
||||
// if any of them should be auto approved by the policy.
|
||||
// If any of them are, add them to the approved routes of the node.
|
||||
// Keep all the old entries and compact the list to remove duplicates.
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range m.node.Hostinfo.RoutableIPs {
|
||||
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, m.node.ApprovedRoutes...)
|
||||
slices.SortFunc(newApproved, util.ComparePrefix)
|
||||
slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
m.node.ApprovedRoutes = newApproved
|
||||
|
||||
// Approve routes if they are auto-approved by the policy.
|
||||
// If any of them are approved, report them to the primary route tracker
|
||||
// and send updates accordingly.
|
||||
if policy.AutoApproveRoutes(m.h.polMan, m.node) {
|
||||
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
|
|
|
@ -150,6 +150,68 @@ func (node *Node) IPs() []netip.Addr {
|
|||
return ret
|
||||
}
|
||||
|
||||
// HasIP reports if a node has a given IP address.
|
||||
func (node *Node) HasIP(i netip.Addr) bool {
|
||||
for _, ip := range node.IPs() {
|
||||
if ip.Compare(i) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTagged reports if a device is tagged
|
||||
// and therefore should not be treated as a
|
||||
// user owned device.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
func (node *Node) IsTagged() bool {
|
||||
if len(node.ForcedTags) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.Hostinfo == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO(kradalby): Figure out how tagging should work
|
||||
// and hostinfo.requestedtags.
|
||||
// Do this in other work.
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HasTag reports if a node has a given tag.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
func (node *Node) HasTag(tag string) bool {
|
||||
if slices.Contains(node.ForcedTags, tag) {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(kradalby): Figure out how tagging should work
|
||||
// and hostinfo.requestedtags.
|
||||
// Do this in other work.
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *Node) RequestTags() []string {
|
||||
if node.Hostinfo == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return node.Hostinfo.RequestTags
|
||||
}
|
||||
|
||||
func (node *Node) Prefixes() []netip.Prefix {
|
||||
addrs := []netip.Prefix{}
|
||||
for _, nodeAddress := range node.IPs() {
|
||||
|
@ -163,12 +225,8 @@ func (node *Node) Prefixes() []netip.Prefix {
|
|||
func (node *Node) IPsAsString() []string {
|
||||
var ret []string
|
||||
|
||||
if node.IPv4 != nil {
|
||||
ret = append(ret, node.IPv4.String())
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
ret = append(ret, node.IPv6.String())
|
||||
for _, ip := range node.IPs() {
|
||||
ret = append(ret, ip.String())
|
||||
}
|
||||
|
||||
return ret
|
||||
|
@ -335,9 +393,9 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
|
|||
return routes
|
||||
}
|
||||
|
||||
// func (node *Node) String() string {
|
||||
// return node.Hostname
|
||||
// }
|
||||
func (node *Node) String() string {
|
||||
return node.Hostname
|
||||
}
|
||||
|
||||
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
|
||||
// to produce a PeerChange struct that can be used to updated the node and
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
|
@ -18,6 +19,19 @@ import (
|
|||
|
||||
type UserID uint64
|
||||
|
||||
type Users []User
|
||||
|
||||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// User is the way Headscale implements the concept of users in Tailscale
|
||||
//
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
|
@ -74,12 +88,13 @@ func (u *User) Username() string {
|
|||
u.Email,
|
||||
u.Name,
|
||||
u.ProviderIdentifier.String,
|
||||
u.StringID())
|
||||
u.StringID(),
|
||||
)
|
||||
}
|
||||
|
||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||
// Display returns the DisplayName if it exists, otherwise
|
||||
// it will return the Username.
|
||||
func (u *User) DisplayNameOrUsername() string {
|
||||
func (u *User) Display() string {
|
||||
return cmp.Or(u.DisplayName, u.Username())
|
||||
}
|
||||
|
||||
|
@ -91,7 +106,7 @@ func (u *User) profilePicURL() string {
|
|||
func (u *User) TailscaleUser() *tailcfg.User {
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
Created: u.CreatedAt,
|
||||
}
|
||||
|
@ -101,11 +116,10 @@ func (u *User) TailscaleUser() *tailcfg.User {
|
|||
|
||||
func (u *User) TailscaleLogin() *tailcfg.Login {
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
// TODO(kradalby): this should reflect registration method.
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
Provider: u.Provider,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
|
||||
|
@ -116,7 +130,7 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
|||
return tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package util
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
|
@ -111,3 +112,16 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// IPSetAddrIter returns a function that iterates over all the IPs in the IPSet.
|
||||
func IPSetAddrIter(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
|
||||
return func(yield func(netip.Addr) bool) {
|
||||
for _, rng := range ipSet.Ranges() {
|
||||
for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
|
||||
if !yield(ip) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
|
@ -13,24 +16,6 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
|||
return d.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Remove when in stdlib;
|
||||
// https://github.com/golang/go/issues/61642
|
||||
// Compare returns an integer comparing two prefixes.
|
||||
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
|
||||
// Prefixes sort first by validity (invalid before valid), then
|
||||
// address family (IPv4 before IPv6), then prefix length, then
|
||||
// address.
|
||||
func ComparePrefix(p, p2 netip.Prefix) int {
|
||||
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
|
||||
return c
|
||||
}
|
||||
|
||||
return p.Addr().Compare(p2.Addr())
|
||||
}
|
||||
|
||||
func PrefixesToString(prefixes []netip.Prefix) []string {
|
||||
ret := make([]string, 0, len(prefixes))
|
||||
for _, prefix := range prefixes {
|
||||
|
@ -49,3 +34,29 @@ func MustStringsToPrefixes(strings []string) []netip.Prefix {
|
|||
|
||||
return ret
|
||||
}
|
||||
|
||||
// TheInternet returns the IPSet for the Internet.
|
||||
// https://www.youtube.com/watch?v=iDbyYGrswtg
|
||||
var TheInternet = sync.OnceValue(func() *netipx.IPSet {
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "can't find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
return theInternetSet
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue