use tsaddr library and cleanups (#2150)
* resuse tsaddr code instead of handrolled Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * ensure we dont give out internal tailscale IPs Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use prefix instead of string for routes Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * remove old custom compare func Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * trim unused util code Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
63035cdb5a
commit
3964dec1c6
19 changed files with 123 additions and 153 deletions
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
// IPAllocator is a singleton responsible for allocating
|
||||
|
@ -190,8 +191,9 @@ func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr,
|
|||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
// Check if the IP has already been allocated.
|
||||
if set.Contains(ip) {
|
||||
// Check if the IP has already been allocated
|
||||
// or if it is a IP reserved by Tailscale.
|
||||
if set.Contains(ip) || isTailscaleReservedIP(ip) {
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
ip = ip.Next()
|
||||
|
@ -248,6 +250,12 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
|||
return ip, nil
|
||||
}
|
||||
|
||||
func isTailscaleReservedIP(ip netip.Addr) bool {
|
||||
return tsaddr.ChromeOSVMRange().Contains(ip) ||
|
||||
tsaddr.TailscaleServiceIP() == ip ||
|
||||
tsaddr.TailscaleServiceIPv6() == ip
|
||||
}
|
||||
|
||||
// BackfillNodeIPs will take a database transaction, and
|
||||
// iterate through all of the current nodes in headscale
|
||||
// and ensure it has IP addresses according to the current
|
||||
|
|
|
@ -12,6 +12,9 @@ import (
|
|||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var mpp = func(pref string) *netip.Prefix {
|
||||
|
@ -514,3 +517,26 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
alloc, err := NewIPAllocator(db, ptr.To(tsaddr.CGNATRange()), ptr.To(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set up ip alloc: %s", err)
|
||||
}
|
||||
|
||||
// Validate that we do not give out 100.100.100.100
|
||||
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
||||
t.Logf("chrome: %s", nextChrome.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -215,7 +216,7 @@ func SetTags(
|
|||
|
||||
var newTags types.StringList
|
||||
for _, tag := range tags {
|
||||
if !util.StringOrPrefixListContains(newTags, tag) {
|
||||
if !slices.Contains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
|
@ -538,34 +539,24 @@ func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
|
|||
|
||||
func (hsdb *HSDatabase) enableRoutes(
|
||||
node *types.Node,
|
||||
routeStrs ...string,
|
||||
newRoutes ...netip.Prefix,
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return enableRoutes(tx, node, routeStrs...)
|
||||
return enableRoutes(tx, node, newRoutes...)
|
||||
})
|
||||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func enableRoutes(tx *gorm.DB,
|
||||
node *types.Node, routeStrs ...string,
|
||||
node *types.Node, newRoutes ...netip.Prefix,
|
||||
) (*types.StateUpdate, error) {
|
||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
|
||||
if !slices.Contains(advertisedRoutes, newRoute) {
|
||||
return nil, fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
node.Hostname,
|
||||
|
@ -607,12 +598,6 @@ func enableRoutes(tx *gorm.DB,
|
|||
|
||||
node.Routes = nRoutes
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Strs("routes", routeStrs).
|
||||
Msg("enabling routes")
|
||||
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{node.ID},
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"math/big"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -20,6 +19,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
|
@ -528,16 +528,16 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("::/0"),
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
},
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("::/0"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -594,9 +594,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Len(t, enabledRoutes, len(tt.want))
|
||||
|
||||
sort.Slice(enabledRoutes, func(i, j int) bool {
|
||||
return util.ComparePrefix(enabledRoutes[i], enabledRoutes[j]) > 0
|
||||
})
|
||||
tsaddr.SortPrefixes(enabledRoutes)
|
||||
|
||||
if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
|
@ -117,12 +118,12 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
|||
return enableRoutes(
|
||||
tx,
|
||||
&route.Node,
|
||||
types.ExitRouteV4.String(),
|
||||
types.ExitRouteV6.String(),
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
)
|
||||
}
|
||||
|
||||
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
|
||||
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
func DisableRoute(tx *gorm.DB,
|
||||
|
|
|
@ -27,6 +27,10 @@ var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] {
|
|||
return s
|
||||
}
|
||||
|
||||
var mp = func(p string) netip.Prefix {
|
||||
return netip.MustParsePrefix(p)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -64,10 +68,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||
|
||||
// TODO(kradalby): check state update
|
||||
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
_, err = db.enableRoutes(&node, mp("192.168.0.0/24"))
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, mp("10.0.0.0/24"))
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
|
@ -119,10 +123,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||
|
||||
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
_, err = db.enableRoutes(&node, mp("192.168.0.0/24"))
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, mp("10.0.0.0/24"))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(&node)
|
||||
|
@ -130,14 +134,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, mp("10.0.0.0/24"))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||
|
||||
_, err = db.enableRoutes(&node, "150.0.10.0/25")
|
||||
_, err = db.enableRoutes(&node, mp("150.0.10.0/25"))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
|
||||
|
@ -183,10 +187,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
_, err = db.enableRoutes(&node1, route.String())
|
||||
_, err = db.enableRoutes(&node1, route)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.enableRoutes(&node1, route2.String())
|
||||
_, err = db.enableRoutes(&node1, route2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
|
@ -206,7 +210,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
_, err = db.enableRoutes(&node2, route2.String())
|
||||
_, err = db.enableRoutes(&node2, route2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||
|
@ -267,10 +271,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
_, err = db.enableRoutes(&node1, prefix.String())
|
||||
_, err = db.enableRoutes(&node1, prefix)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.enableRoutes(&node1, prefix2.String())
|
||||
_, err = db.enableRoutes(&node1, prefix2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := db.GetNodeRoutes(&node1)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
|
@ -195,7 +196,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
|
||||
Prefix: types.IPPrefix(tsaddr.AllIPv4()),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: false,
|
||||
|
@ -234,7 +235,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
DERP: "127.3.3.40:0",
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
@ -108,7 +109,7 @@ func TestTailNode(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
|
||||
Prefix: types.IPPrefix(tsaddr.AllIPv4()),
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
IsPrimary: false,
|
||||
|
@ -152,7 +153,7 @@ func TestTailNode(t *testing.T) {
|
|||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
DERP: "127.3.3.40:0",
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -45,7 +47,7 @@ func theInternet() *netipx.IPSet {
|
|||
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
|
@ -55,8 +57,8 @@ func theInternet() *netipx.IPSet {
|
|||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fd7a:115c:a1e0::/48"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("100.64.0.0/10"))
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "cant find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-loca
|
||||
|
@ -603,7 +605,7 @@ func excludeCorrectlyTaggedNodes(
|
|||
for tag := range aclPolicy.TagOwners {
|
||||
owners, _ := expandOwnersFromTag(aclPolicy, user)
|
||||
ns := append(owners, user)
|
||||
if util.StringOrPrefixListContains(ns, user) {
|
||||
if slices.Contains(ns, user) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
|
@ -616,7 +618,7 @@ func excludeCorrectlyTaggedNodes(
|
|||
}
|
||||
|
||||
for _, t := range node.Hostinfo.RequestTags {
|
||||
if util.StringOrPrefixListContains(tags, t) {
|
||||
if slices.Contains(tags, t) {
|
||||
found = true
|
||||
|
||||
break
|
||||
|
@ -779,7 +781,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
|
||||
// check for forced tags
|
||||
for _, node := range nodes {
|
||||
if util.StringOrPrefixListContains(node.ForcedTags, alias) {
|
||||
if slices.Contains(node.ForcedTags, alias) {
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
@ -811,7 +813,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
continue
|
||||
}
|
||||
|
||||
if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) {
|
||||
if slices.Contains(node.Hostinfo.RequestTags, alias) {
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package policy
|
|||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -341,7 +343,7 @@ func TestParsing(t *testing.T) {
|
|||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -1998,7 +2000,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{types.ExitRouteV4, types.ExitRouteV6},
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -2036,7 +2038,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{types.ExitRouteV4, types.ExitRouteV6},
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
|
@ -2132,7 +2134,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{types.ExitRouteV4, types.ExitRouteV6},
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
|
@ -2548,7 +2550,7 @@ func Test_getTags(t *testing.T) {
|
|||
test.args.node,
|
||||
)
|
||||
for _, valid := range gotValid {
|
||||
if !util.StringOrPrefixListContains(test.wantValid, valid) {
|
||||
if !slices.Contains(test.wantValid, valid) {
|
||||
t.Errorf(
|
||||
"valids: getTags() = %v, want %v",
|
||||
gotValid,
|
||||
|
@ -2559,7 +2561,7 @@ func Test_getTags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
for _, invalid := range gotInvalid {
|
||||
if !util.StringOrPrefixListContains(test.wantInvalid, invalid) {
|
||||
if !slices.Contains(test.wantInvalid, invalid) {
|
||||
t.Errorf(
|
||||
"invalids: getTags() = %v, want %v",
|
||||
gotInvalid,
|
||||
|
|
|
@ -6,18 +6,17 @@ import (
|
|||
"math/rand/v2"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -666,12 +665,8 @@ func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
|||
oldRoutes := old.RoutableIPs
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
sort.Slice(oldRoutes, func(i, j int) bool {
|
||||
return util.ComparePrefix(oldRoutes[i], oldRoutes[j]) > 0
|
||||
})
|
||||
sort.Slice(newRoutes, func(i, j int) bool {
|
||||
return util.ComparePrefix(newRoutes[i], newRoutes[j]) > 0
|
||||
})
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
|
|
|
@ -7,11 +7,7 @@ import (
|
|||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
||||
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
|
@ -35,7 +31,7 @@ func (r *Route) String() string {
|
|||
}
|
||||
|
||||
func (r *Route) IsExitRoute() bool {
|
||||
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
|
||||
return tsaddr.IsExitRoute(netip.Prefix(r.Prefix))
|
||||
}
|
||||
|
||||
func (r *Route) IsAnnouncable() bool {
|
||||
|
|
|
@ -3,7 +3,6 @@ package util
|
|||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go4.org/netipx"
|
||||
|
@ -104,7 +103,7 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
|
|||
for index, prefixStr := range prefixes {
|
||||
prefix, err := netip.ParsePrefix(prefixStr)
|
||||
if err != nil {
|
||||
return []netip.Prefix{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[index] = prefix
|
||||
|
@ -112,13 +111,3 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func StringOrPrefixListContains[T string | netip.Prefix](ts []T, t T) bool {
|
||||
for _, v := range ts {
|
||||
if reflect.DeepEqual(v, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,33 +1,10 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
var (
|
||||
NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
|
||||
ErrCannotDecryptResponse = errors.New("cannot decrypt response")
|
||||
ZstdCompression = "zstd"
|
||||
)
|
||||
|
||||
func DecodeAndUnmarshalNaCl(
|
||||
msg []byte,
|
||||
output interface{},
|
||||
pubKey *key.MachinePublic,
|
||||
privKey *key.MachinePrivate,
|
||||
) error {
|
||||
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
||||
if !ok {
|
||||
return ErrCannotDecryptResponse
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
|
@ -12,19 +10,3 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
|||
|
||||
return d.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Remove after go 1.24, will be in stdlib.
|
||||
// Compare returns an integer comparing two prefixes.
|
||||
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
|
||||
// Prefixes sort first by validity (invalid before valid), then
|
||||
// address family (IPv4 before IPv6), then prefix length, then
|
||||
// address.
|
||||
func ComparePrefix(p, p2 netip.Prefix) int {
|
||||
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
|
||||
return c
|
||||
}
|
||||
return p.Addr().Compare(p2.Addr())
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue