Migrate IP fields in database to dedicated columns (#1869)
This commit is contained in:
parent
85cef84e17
commit
2ce23df45a
39 changed files with 1885 additions and 1055 deletions
|
@ -9,7 +9,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof" //nolint
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
|
@ -56,6 +55,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -148,7 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
app.ipAlloc, err = db.NewIPAllocator(app.db, *cfg.PrefixV4, *cfg.PrefixV6)
|
||||
app.ipAlloc, err = db.NewIPAllocator(app.db, cfg.PrefixV4, cfg.PrefixV6, cfg.IPAllocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -166,7 +166,15 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
|
||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
// TODO(kradalby): revisit why this takes a list.
|
||||
magicDNSDomains := util.GenerateMagicDNSRootDomains([]netip.Prefix{*cfg.PrefixV4, *cfg.PrefixV6})
|
||||
|
||||
var magicDNSDomains []dnsname.FQDN
|
||||
if cfg.PrefixV4 != nil {
|
||||
magicDNSDomains = append(magicDNSDomains, util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||
}
|
||||
if cfg.PrefixV6 != nil {
|
||||
magicDNSDomains = append(magicDNSDomains, util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
|
||||
}
|
||||
|
||||
// we might have routes already from Split DNS
|
||||
if app.cfg.DNSConfig.Routes == nil {
|
||||
app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
||||
|
|
|
@ -383,7 +383,7 @@ func (h *Headscale) handleAuthKey(
|
|||
ForcedTags: pak.Proto().GetAclTags(),
|
||||
}
|
||||
|
||||
addrs, err := h.ipAlloc.Next()
|
||||
ipv4, ipv6, err := h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -397,7 +397,7 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
node, err = h.db.RegisterNode(
|
||||
nodeToRegister,
|
||||
addrs,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -461,7 +461,6 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
log.Info().
|
||||
Str("node", registerRequest.Hostinfo.Hostname).
|
||||
Str("ips", strings.Join(node.IPAddresses.StringSlice(), ", ")).
|
||||
Msg("Successfully authenticated via AuthKey")
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -330,6 +331,66 @@ func NewHeadscaleDatabase(
|
|||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Replace column with IP address list with dedicated
|
||||
// IP v4 and v6 column.
|
||||
// Note that previously, the list _could_ contain more
|
||||
// than two addresses, which should not really happen.
|
||||
// In that case, the first occurence of each type will
|
||||
// be kept.
|
||||
ID: "2024041121742",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
_ = tx.Migrator().AddColumn(&types.Node{}, "ipv4")
|
||||
_ = tx.Migrator().AddColumn(&types.Node{}, "ipv6")
|
||||
|
||||
type node struct {
|
||||
ID uint64 `gorm:"column:id"`
|
||||
Addresses string `gorm:"column:ip_addresses"`
|
||||
}
|
||||
|
||||
var nodes []node
|
||||
|
||||
_ = tx.Raw("SELECT id, ip_addresses FROM nodes").Scan(&nodes).Error
|
||||
|
||||
for _, node := range nodes {
|
||||
addrs := strings.Split(node.Addresses, ",")
|
||||
|
||||
if len(addrs) == 0 {
|
||||
fmt.Errorf("no addresses found for node(%d)", node.ID)
|
||||
}
|
||||
|
||||
var v4 *netip.Addr
|
||||
var v6 *netip.Addr
|
||||
|
||||
for _, addrStr := range addrs {
|
||||
addr, err := netip.ParseAddr(addrStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing IP for node(%d) from database: %w", node.ID, err)
|
||||
}
|
||||
|
||||
if addr.Is4() && v4 == nil {
|
||||
v4 = &addr
|
||||
}
|
||||
|
||||
if addr.Is6() && v6 == nil {
|
||||
v6 = &addr
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Save(&types.Node{ID: types.NodeID(node.ID), IPv4: v4, IPv6: v6}).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving ip addresses to new columns: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_ = tx.Migrator().DropColumn(&types.Node{}, "ip_addresses")
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
@ -20,13 +24,16 @@ import (
|
|||
type IPAllocator struct {
|
||||
mu sync.Mutex
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
prefix4 *netip.Prefix
|
||||
prefix6 *netip.Prefix
|
||||
|
||||
// Previous IPs handed out
|
||||
prev4 netip.Addr
|
||||
prev6 netip.Addr
|
||||
|
||||
// strategy used for handing out IP addresses.
|
||||
strategy types.IPAllocationStrategy
|
||||
|
||||
// Set of all IPs handed out.
|
||||
// This might not be in sync with the database,
|
||||
// but it is more conservative. If saves to the
|
||||
|
@ -40,40 +47,71 @@ type IPAllocator struct {
|
|||
// provided IPv4 and IPv6 prefix. It needs to be created
|
||||
// when headscale starts and needs to finish its read
|
||||
// transaction before any writes to the database occur.
|
||||
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
|
||||
var addressesSlices []string
|
||||
func NewIPAllocator(
|
||||
db *HSDatabase,
|
||||
prefix4, prefix6 *netip.Prefix,
|
||||
strategy types.IPAllocationStrategy,
|
||||
) (*IPAllocator, error) {
|
||||
ret := IPAllocator{
|
||||
prefix4: prefix4,
|
||||
prefix6: prefix6,
|
||||
|
||||
strategy: strategy,
|
||||
}
|
||||
|
||||
var v4s []sql.NullString
|
||||
var v6s []sql.NullString
|
||||
|
||||
if db != nil {
|
||||
db.Read(func(rx *gorm.DB) error {
|
||||
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
|
||||
err := db.Read(func(rx *gorm.DB) error {
|
||||
return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err)
|
||||
}
|
||||
|
||||
err = db.Read(func(rx *gorm.DB) error {
|
||||
return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
// Add network and broadcast addrs to used pool so they
|
||||
// are not handed out to nodes.
|
||||
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
|
||||
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
|
||||
ips.Add(network4)
|
||||
ips.Add(broadcast4)
|
||||
ips.Add(network6)
|
||||
ips.Add(broadcast6)
|
||||
if prefix4 != nil {
|
||||
network4, broadcast4 := util.GetIPPrefixEndpoints(*prefix4)
|
||||
ips.Add(network4)
|
||||
ips.Add(broadcast4)
|
||||
|
||||
// Use network as starting point, it will be used to call .Next()
|
||||
// TODO(kradalby): Could potentially take all the IPs loaded from
|
||||
// the database into account to start at a more "educated" location.
|
||||
ret.prev4 = network4
|
||||
}
|
||||
|
||||
if prefix6 != nil {
|
||||
network6, broadcast6 := util.GetIPPrefixEndpoints(*prefix6)
|
||||
ips.Add(network6)
|
||||
ips.Add(broadcast6)
|
||||
|
||||
ret.prev6 = network6
|
||||
}
|
||||
|
||||
// Fetch all the IP Addresses currently handed out from the Database
|
||||
// and add them to the used IP set.
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.NodeAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"parsing IPs from database %v: %w", machineAddresses,
|
||||
err,
|
||||
)
|
||||
}
|
||||
for _, addrStr := range append(v4s, v6s...) {
|
||||
if addrStr.Valid {
|
||||
addr, err := netip.ParseAddr(addrStr.String)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing IP address from database: %w", err)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
ips.Add(addr)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -86,42 +124,61 @@ func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator
|
|||
)
|
||||
}
|
||||
|
||||
return &IPAllocator{
|
||||
usedIPs: ips,
|
||||
ret.usedIPs = ips
|
||||
|
||||
prefix4: prefix4,
|
||||
prefix6: prefix6,
|
||||
|
||||
// Use network as starting point, it will be used to call .Next()
|
||||
// TODO(kradalby): Could potentially take all the IPs loaded from
|
||||
// the database into account to start at a more "educated" location.
|
||||
prev4: network4,
|
||||
prev6: network6,
|
||||
}, nil
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
|
||||
func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
v4, err := i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
var err error
|
||||
var ret4 *netip.Addr
|
||||
var ret6 *netip.Addr
|
||||
|
||||
if i.prefix4 != nil {
|
||||
ret4, err = i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
i.prev4 = *ret4
|
||||
}
|
||||
|
||||
v6, err := i.next(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
if i.prefix6 != nil {
|
||||
ret6, err = i.next(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
i.prev6 = *ret6
|
||||
}
|
||||
|
||||
return types.NodeAddresses{*v4, *v6}, nil
|
||||
return ret4, ret6, nil
|
||||
}
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
|
||||
// Get the first IP in our prefix
|
||||
ip := prev.Next()
|
||||
func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
return i.next(prev, prefix)
|
||||
}
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||
var err error
|
||||
var ip netip.Addr
|
||||
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
// Get the first IP in our prefix
|
||||
ip = prev.Next()
|
||||
case types.IPAllocationStrategyRandom:
|
||||
ip, err = randomNext(*prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting random IP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): maybe this can be done less often.
|
||||
set, err := i.usedIPs.IPSet()
|
||||
|
@ -136,7 +193,15 @@ func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, e
|
|||
|
||||
// Check if the IP has already been allocated.
|
||||
if set.Contains(ip) {
|
||||
ip = ip.Next()
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
ip = ip.Next()
|
||||
case types.IPAllocationStrategyRandom:
|
||||
ip, err = randomNext(*prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting random IP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
@ -146,3 +211,120 @@ func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, e
|
|||
return &ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
||||
rang := netipx.RangeOfPrefix(pfx)
|
||||
fromIP, toIP := rang.From(), rang.To()
|
||||
|
||||
var from, to big.Int
|
||||
|
||||
from.SetBytes(fromIP.AsSlice())
|
||||
to.SetBytes(toIP.AsSlice())
|
||||
|
||||
// Find the max, this is how we can do "random range",
|
||||
// get the "max" as 0 -> to - from and then add back from
|
||||
// after.
|
||||
tempMax := big.NewInt(0).Sub(&to, &from)
|
||||
|
||||
out, err := rand.Int(rand.Reader, tempMax)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("generating random IP: %w", err)
|
||||
}
|
||||
|
||||
valInRange := big.NewInt(0).Add(&from, out)
|
||||
|
||||
ip, ok := netip.AddrFromSlice(valInRange.Bytes())
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("generated ip bytes are invalid ip")
|
||||
}
|
||||
|
||||
if !pfx.Contains(ip) {
|
||||
return netip.Addr{}, fmt.Errorf(
|
||||
"generated ip(%s) not in prefix(%s)",
|
||||
ip.String(),
|
||||
pfx.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// BackfillNodeIPs will take a database transaction, and
|
||||
// iterate through all of the current nodes in headscale
|
||||
// and ensure it has IP addresses according to the current
|
||||
// configuration.
|
||||
// This means that if both IPv4 and IPv6 is set in the
|
||||
// config, and some nodes are missing that type of IP,
|
||||
// it will be added.
|
||||
// If a prefix type has been removed (IPv4 or IPv6), it
|
||||
// will remove the IPs in that family from the node.
|
||||
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
var err error
|
||||
var ret []string
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
if i == nil {
|
||||
return errors.New("backfilling IPs: ip allocator was nil")
|
||||
}
|
||||
|
||||
log.Trace().Msgf("starting to backfill IPs")
|
||||
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing nodes to backfill IPs: %w", err)
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
|
||||
|
||||
changed := false
|
||||
// IPv4 prefix is set, but node ip is missing, alloc
|
||||
if i.prefix4 != nil && node.IPv4 == nil {
|
||||
ret4, err := i.nextLocked(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
node.IPv4 = ret4
|
||||
changed = true
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
// IPv6 prefix is set, but node ip is missing, alloc
|
||||
if i.prefix6 != nil && node.IPv6 == nil {
|
||||
ret6, err := i.nextLocked(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
node.IPv6 = ret6
|
||||
changed = true
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
// IPv4 prefix is not set, but node has IP, remove
|
||||
if i.prefix4 == nil && node.IPv4 != nil {
|
||||
ret = append(ret, fmt.Sprintf("removing IPv4 %q from Node(%d) %q", node.IPv4.String(), node.ID, node.Hostname))
|
||||
node.IPv4 = nil
|
||||
changed = true
|
||||
}
|
||||
|
||||
// IPv6 prefix is not set, but node has IP, remove
|
||||
if i.prefix6 == nil && node.IPv6 != nil {
|
||||
ret = append(ret, fmt.Sprintf("removing IPv6 %q from Node(%d) %q", node.IPv6.String(), node.ID, node.Hostname))
|
||||
node.IPv6 = nil
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
err := tx.Save(node).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return ret, err
|
||||
}
|
||||
|
|
|
@ -1,49 +1,41 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
func TestIPAllocator(t *testing.T) {
|
||||
mpp := func(pref string) netip.Prefix {
|
||||
return netip.MustParsePrefix(pref)
|
||||
}
|
||||
na := func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
newDb := func() *HSDatabase {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("creating temp dir: %s", err)
|
||||
}
|
||||
db, _ = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
return db
|
||||
}
|
||||
var mpp = func(pref string) *netip.Prefix {
|
||||
p := netip.MustParsePrefix(pref)
|
||||
return &p
|
||||
}
|
||||
var na = func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
var nap = func(pref string) *netip.Addr {
|
||||
n := na(pref)
|
||||
return &n
|
||||
}
|
||||
|
||||
func TestIPAllocatorSequential(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbFunc func() *HSDatabase
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
prefix4 *netip.Prefix
|
||||
prefix6 *netip.Prefix
|
||||
getCount int
|
||||
want []types.NodeAddresses
|
||||
want4 []netip.Addr
|
||||
want6 []netip.Addr
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
|
@ -56,23 +48,49 @@ func TestIPAllocator(t *testing.T) {
|
|||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
want4: []netip.Addr{
|
||||
na("100.64.0.1"),
|
||||
},
|
||||
want6: []netip.Addr{
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-v4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want4: []netip.Addr{
|
||||
na("100.64.0.1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-v6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want6: []netip.Addr{
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-with-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
db := dbForTest(t, "simple-with-db")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
})
|
||||
|
||||
return db
|
||||
|
@ -83,23 +101,21 @@ func TestIPAllocator(t *testing.T) {
|
|||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
want4: []netip.Addr{
|
||||
na("100.64.0.2"),
|
||||
},
|
||||
want6: []netip.Addr{
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "before-after-free-middle-in-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
db := dbForTest(t, "before-after-free-middle-in-db")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
IPv4: nap("100.64.0.2"),
|
||||
IPv6: nap("fd7a:115c:a1e0::2"),
|
||||
})
|
||||
|
||||
return db
|
||||
|
@ -110,15 +126,13 @@ func TestIPAllocator(t *testing.T) {
|
|||
|
||||
getCount: 2,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
{
|
||||
na("100.64.0.3"),
|
||||
na("fd7a:115c:a1e0::3"),
|
||||
},
|
||||
want4: []netip.Addr{
|
||||
na("100.64.0.1"),
|
||||
na("100.64.0.3"),
|
||||
},
|
||||
want6: []netip.Addr{
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
na("fd7a:115c:a1e0::3"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -127,24 +141,347 @@ func TestIPAllocator(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.dbFunc()
|
||||
|
||||
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
|
||||
alloc, _ := NewIPAllocator(
|
||||
db,
|
||||
tt.prefix4,
|
||||
tt.prefix6,
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
|
||||
spew.Dump(alloc)
|
||||
|
||||
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
|
||||
|
||||
var got []types.NodeAddresses
|
||||
var got4s []netip.Addr
|
||||
var got6s []netip.Addr
|
||||
|
||||
for range tt.getCount {
|
||||
gotSet, err := alloc.Next()
|
||||
got4, got6, err := alloc.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("allocating next IP: %s", err)
|
||||
}
|
||||
|
||||
got = append(got, gotSet)
|
||||
if got4 != nil {
|
||||
got4s = append(got4s, *got4)
|
||||
}
|
||||
|
||||
if got6 != nil {
|
||||
got6s = append(got6s, *got6)
|
||||
}
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
|
||||
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want6, got6s, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator 6s unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllocatorRandom(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbFunc func() *HSDatabase
|
||||
|
||||
getCount int
|
||||
|
||||
prefix4 *netip.Prefix
|
||||
prefix6 *netip.Prefix
|
||||
want4 bool
|
||||
want6 bool
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want4: true,
|
||||
want6: true,
|
||||
},
|
||||
{
|
||||
name: "simple-v4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want4: true,
|
||||
want6: false,
|
||||
},
|
||||
{
|
||||
name: "simple-v6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want4: false,
|
||||
want6: true,
|
||||
},
|
||||
{
|
||||
name: "generate-lots-of-random",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1000,
|
||||
|
||||
want4: true,
|
||||
want6: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.dbFunc()
|
||||
|
||||
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom)
|
||||
|
||||
spew.Dump(alloc)
|
||||
|
||||
for range tt.getCount {
|
||||
got4, got6, err := alloc.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("allocating next IP: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("addrs ipv4: %v, ipv6: %v", got4, got6)
|
||||
|
||||
if tt.want4 {
|
||||
if got4 == nil {
|
||||
t.Fatalf("expected ipv4 addr, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.want6 {
|
||||
if got6 == nil {
|
||||
t.Fatalf("expected ipv4 addr, got nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillIPAddresses(t *testing.T) {
|
||||
fullNodeP := func(i int) *types.Node {
|
||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||
return &types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: v4,
|
||||
},
|
||||
IPv4: nap(v4),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: v6,
|
||||
},
|
||||
IPv6: nap(v6),
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
dbFunc func() *HSDatabase
|
||||
|
||||
prefix4 *netip.Prefix
|
||||
prefix6 *netip.Prefix
|
||||
want types.Nodes
|
||||
}{
|
||||
{
|
||||
name: "simple-backfill-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv6")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.1"),
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-backfill-ipv4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv4")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-backfill-remove-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-remove-ipv6")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-backfill-remove-ipv4",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-remove-ipv4")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-backfill-ipv6",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := dbForTest(t, "simple-backfill-ipv6")
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.1"),
|
||||
})
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.2"),
|
||||
})
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.3"),
|
||||
})
|
||||
db.DB.Save(&types.Node{
|
||||
IPv4: nap("100.64.0.4"),
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
want: types.Nodes{
|
||||
fullNodeP(1),
|
||||
fullNodeP(2),
|
||||
fullNodeP(3),
|
||||
fullNodeP(4),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
|
||||
"ID",
|
||||
"MachineKeyDatabaseField",
|
||||
"NodeKeyDatabaseField",
|
||||
"DiscoKeyDatabaseField",
|
||||
"Endpoints",
|
||||
"HostinfoDatabaseField",
|
||||
"Hostinfo",
|
||||
"Routes",
|
||||
"CreatedAt",
|
||||
"UpdatedAt",
|
||||
))
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.dbFunc()
|
||||
|
||||
alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set up ip alloc: %s", err)
|
||||
}
|
||||
|
||||
logs, err := db.BackfillNodeIPs(alloc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to backfill: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("backfill log: \n%s", strings.Join(logs, "\n"))
|
||||
|
||||
got, err := db.ListNodes()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get nodes: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, comps...); diff != "" {
|
||||
t.Errorf("Backfill unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
|
@ -294,7 +293,8 @@ func RegisterNodeFromAuthCallback(
|
|||
userName string,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
addrs types.NodeAddresses,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
|
@ -330,7 +330,7 @@ func RegisterNodeFromAuthCallback(
|
|||
node, err := RegisterNode(
|
||||
tx,
|
||||
registrationNode,
|
||||
addrs,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
|
@ -346,14 +346,14 @@ func RegisterNodeFromAuthCallback(
|
|||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, addrs)
|
||||
return RegisterNode(tx, node, ipv4, ipv6)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
|
@ -361,10 +361,10 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
|
|||
Str("user", node.User.Name).
|
||||
Msg("Registering node")
|
||||
|
||||
// If the node exists and we had already IPs for it, we just save it
|
||||
// If the node exists and it already has IP(s), we just save it
|
||||
// so we store the node.Expire and node.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if len(node.IPAddresses) > 0 {
|
||||
if node.IPv4 != nil || node.IPv6 != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
|
||||
}
|
||||
|
@ -380,7 +380,8 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
|
|||
return &node, nil
|
||||
}
|
||||
|
||||
node.IPAddresses = addrs
|
||||
node.IPv4 = ipv4
|
||||
node.IPv6 = ipv6
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
|
@ -389,7 +390,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
|
|||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str("ip", strings.Join(addrs.StringSlice(), ",")).
|
||||
Msg("Node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
|
|
|
@ -188,13 +188,12 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
|
||||
},
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPv4: &v4,
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: stor[index%2].user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
@ -301,27 +300,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
|||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||
input := types.NodeAddresses([]netip.Addr{
|
||||
netip.MustParseAddr("192.0.2.1"),
|
||||
netip.MustParseAddr("2001:db8::1"),
|
||||
})
|
||||
serialized, err := input.Value()
|
||||
c.Assert(err, check.IsNil)
|
||||
if serial, ok := serialized.(string); ok {
|
||||
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
|
||||
}
|
||||
|
||||
var deserialized types.NodeAddresses
|
||||
err = deserialized.Scan(serialized)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(deserialized), check.Equals, len(input))
|
||||
for i := range deserialized {
|
||||
c.Assert(deserialized[i], check.Equals, input[i])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
user1, err := db.CreateUser("user-1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -561,6 +539,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
// Check if a subprefix of an autoapproved route is approved
|
||||
route2 := netip.MustParsePrefix("10.11.0.0/24")
|
||||
|
||||
v4 := netip.MustParseAddr("100.64.0.1")
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
|
@ -573,7 +552,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
RequestTags: []string{"tag:exit"},
|
||||
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
|
||||
},
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
IPv4: &v4,
|
||||
}
|
||||
|
||||
db.DB.Save(&node)
|
||||
|
|
|
@ -609,7 +609,7 @@ func EnableAutoApprovedRoutes(
|
|||
aclPolicy *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
) error {
|
||||
if len(node.IPAddresses) == 0 {
|
||||
if node.IPv4 == nil && node.IPv6 == nil {
|
||||
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
|
@ -652,7 +652,7 @@ func EnableAutoApprovedRoutes(
|
|||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if approvedIps.Contains(node.IPAddresses[0]) {
|
||||
if approvedIps.Contains(*node.IPv4) {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package hscontrol
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -195,7 +196,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
addrs, err := api.h.ipAlloc.Next()
|
||||
ipv4, ipv6, err := api.h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -208,7 +209,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
addrs,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -468,6 +469,24 @@ func (api headscaleV1APIServer) MoveNode(
|
|||
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) BackfillNodeIPs(
|
||||
ctx context.Context,
|
||||
request *v1.BackfillNodeIPsRequest,
|
||||
) (*v1.BackfillNodeIPsResponse, error) {
|
||||
log.Trace().Msg("Backfill called")
|
||||
|
||||
if !request.Confirmed {
|
||||
return nil, errors.New("not confirmed, aborting")
|
||||
}
|
||||
|
||||
changes, err := api.h.db.BackfillNodeIPs(api.h.ipAlloc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.BackfillNodeIPsResponse{Changes: changes}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) GetRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.GetRoutesRequest,
|
||||
|
|
|
@ -174,8 +174,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|||
"device_model": []string{node.Hostinfo.OS},
|
||||
}
|
||||
|
||||
if len(node.IPAddresses) > 0 {
|
||||
attrs.Add("device_ip", node.IPAddresses[0].String())
|
||||
if len(node.IPs()) > 0 {
|
||||
attrs.Add("device_ip", node.IPs()[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
|
|
|
@ -17,6 +17,11 @@ import (
|
|||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
var iap = func(ipStr string) *netip.Addr {
|
||||
ip := netip.MustParseAddr(ipStr)
|
||||
return &ip
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
mach := func(hostname, username string, userid uint) *types.Node {
|
||||
return &types.Node{
|
||||
|
@ -176,17 +181,17 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
ForcedTags: []string{},
|
||||
AuthKeyID: 0,
|
||||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
IPv4: iap("100.64.0.1"),
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
ForcedTags: []string{},
|
||||
AuthKeyID: 0,
|
||||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{
|
||||
{
|
||||
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
|
||||
|
@ -257,17 +262,17 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
Hostname: "peer1",
|
||||
GivenName: "peer1",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
IPv4: iap("100.64.0.2"),
|
||||
Hostname: "peer1",
|
||||
GivenName: "peer1",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
||||
tailPeer1 := &tailcfg.Node{
|
||||
|
@ -312,17 +317,17 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
Hostname: "peer2",
|
||||
GivenName: "peer2",
|
||||
UserID: 1,
|
||||
User: types.User{Name: "peer2"},
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
IPv4: iap("100.64.0.3"),
|
||||
Hostname: "peer2",
|
||||
GivenName: "peer2",
|
||||
UserID: 1,
|
||||
User: types.User{Name: "peer2"},
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
Routes: []types.Route{},
|
||||
CreatedAt: created,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
|
|
@ -44,7 +44,7 @@ func tailNode(
|
|||
pol *policy.ACLPolicy,
|
||||
cfg *types.Config,
|
||||
) (*tailcfg.Node, error) {
|
||||
addrs := node.IPAddresses.Prefixes()
|
||||
addrs := node.Prefixes()
|
||||
|
||||
allowedIPs := append(
|
||||
[]netip.Prefix{},
|
||||
|
|
|
@ -89,9 +89,7 @@ func TestTailNode(t *testing.T) {
|
|||
DiscoKey: mustDK(
|
||||
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||
),
|
||||
IPAddresses: []netip.Addr{
|
||||
netip.MustParseAddr("100.64.0.1"),
|
||||
},
|
||||
IPv4: iap("100.64.0.1"),
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: 0,
|
||||
|
|
|
@ -597,7 +597,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
machineKey *key.MachinePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
addrs, err := h.ipAlloc.Next()
|
||||
ipv4, ipv6, err := h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -611,7 +611,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
user.Name,
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
addrs,
|
||||
ipv4, ipv6,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -229,7 +229,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
|||
continue
|
||||
}
|
||||
|
||||
if node.IPAddresses.InIPSet(expanded) {
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
}
|
||||
|
||||
|
@ -306,7 +306,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if !node.IPAddresses.InIPSet(destSet) {
|
||||
if !node.InIPSet(destSet) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -744,7 +744,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
|
|||
for _, user := range users {
|
||||
filteredNodes := filterNodesByUser(nodes, user)
|
||||
for _, node := range filteredNodes {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -760,7 +760,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
// check for forced tags
|
||||
for _, node := range nodes {
|
||||
if util.StringOrPrefixListContains(node.ForcedTags, alias) {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -792,7 +792,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
}
|
||||
|
||||
if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -815,7 +815,7 @@ func (pol *ACLPolicy) expandIPsFromUser(
|
|||
}
|
||||
|
||||
for _, node := range filteredNodes {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
|
||||
return build.IPSet()
|
||||
|
@ -833,7 +833,7 @@ func (pol *ACLPolicy) expandIPsFromSingleIP(
|
|||
build.Add(ip)
|
||||
|
||||
for _, node := range matches {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
|
||||
return build.IPSet()
|
||||
|
@ -850,11 +850,11 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix(
|
|||
// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
|
||||
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
|
||||
for _, node := range nodes {
|
||||
for _, ip := range node.IPAddresses {
|
||||
for _, ip := range node.IPs() {
|
||||
// log.Trace().
|
||||
// Msgf("checking if node ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
|
||||
if prefix.Contains(ip) {
|
||||
node.IPAddresses.AppendToIPSet(&build)
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -31,6 +31,13 @@ var errOidcMutuallyExclusive = errors.New(
|
|||
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
|
||||
)
|
||||
|
||||
type IPAllocationStrategy string
|
||||
|
||||
const (
|
||||
IPAllocationStrategySequential IPAllocationStrategy = "sequential"
|
||||
IPAllocationStrategyRandom IPAllocationStrategy = "random"
|
||||
)
|
||||
|
||||
// Config contains the initial Headscale configuration.
|
||||
type Config struct {
|
||||
ServerURL string
|
||||
|
@ -42,6 +49,7 @@ type Config struct {
|
|||
NodeUpdateCheckInterval time.Duration
|
||||
PrefixV4 *netip.Prefix
|
||||
PrefixV6 *netip.Prefix
|
||||
IPAllocation IPAllocationStrategy
|
||||
NoisePrivateKeyPath string
|
||||
BaseDomain string
|
||||
Log LogConfig
|
||||
|
@ -230,6 +238,8 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("tuning.batch_change_delay", "800ms")
|
||||
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
|
||||
|
||||
viper.SetDefault("prefixes.allocation", IPAllocationStrategySequential)
|
||||
|
||||
if IsCLIConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
@ -579,18 +589,16 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
|||
return nil, ""
|
||||
}
|
||||
|
||||
func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
|
||||
func PrefixV4() (*netip.Prefix, error) {
|
||||
prefixV4Str := viper.GetString("prefixes.v4")
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
if prefixV4Str == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
prefixV4, err := netip.ParsePrefix(prefixV4Str)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
prefixV6, err := netip.ParsePrefix(prefixV6Str)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, fmt.Errorf("parsing IPv4 prefix from config: %w", err)
|
||||
}
|
||||
|
||||
builder := netipx.IPSetBuilder{}
|
||||
|
@ -603,13 +611,33 @@ func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
|
|||
prefixV4Str, tsaddr.CGNATRange())
|
||||
}
|
||||
|
||||
return &prefixV4, nil
|
||||
}
|
||||
|
||||
func PrefixV6() (*netip.Prefix, error) {
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
if prefixV6Str == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
prefixV6, err := netip.ParsePrefix(prefixV6Str)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing IPv6 prefix from config: %w", err)
|
||||
}
|
||||
|
||||
builder := netipx.IPSetBuilder{}
|
||||
builder.AddPrefix(tsaddr.CGNATRange())
|
||||
builder.AddPrefix(tsaddr.TailscaleULARange())
|
||||
ipSet, _ := builder.IPSet()
|
||||
|
||||
if !ipSet.ContainsPrefix(prefixV6) {
|
||||
log.Warn().
|
||||
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
||||
prefixV6Str, tsaddr.TailscaleULARange())
|
||||
}
|
||||
|
||||
return &prefixV4, &prefixV6, nil
|
||||
return &prefixV6, nil
|
||||
}
|
||||
|
||||
func GetHeadscaleConfig() (*Config, error) {
|
||||
|
@ -624,11 +652,27 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
prefix4, prefix6, err := Prefixes()
|
||||
prefix4, err := PrefixV4()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix6, err := PrefixV6()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allocStr := viper.GetString("prefixes.allocation")
|
||||
var alloc IPAllocationStrategy
|
||||
switch allocStr {
|
||||
case string(IPAllocationStrategySequential):
|
||||
alloc = IPAllocationStrategySequential
|
||||
case string(IPAllocationStrategyRandom):
|
||||
alloc = IPAllocationStrategyRandom
|
||||
default:
|
||||
log.Fatal().Msgf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
|
||||
}
|
||||
|
||||
dnsConfig, baseDomain := GetDNSConfig()
|
||||
derpConfig := GetDERPConfig()
|
||||
logConfig := GetLogTailConfig()
|
||||
|
@ -655,8 +699,9 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
IPAllocation: IPAllocationStrategy(alloc),
|
||||
|
||||
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
|
||||
viper.GetString("noise.private_key_path"),
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -14,7 +13,6 @@ import (
|
|||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
|
@ -83,7 +81,19 @@ type Node struct {
|
|||
HostinfoDatabaseField string `gorm:"column:host_info"`
|
||||
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
|
||||
|
||||
IPAddresses NodeAddresses
|
||||
// IPv4DatabaseField is the string representation of v4 address,
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use V4 instead.
|
||||
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
|
||||
IPv4 *netip.Addr `gorm:"-"`
|
||||
|
||||
// IPv6DatabaseField is the string representation of v4 address,
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use V6 instead.
|
||||
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
|
||||
IPv6 *netip.Addr `gorm:"-"`
|
||||
|
||||
// Hostname represents the name given by the Tailscale
|
||||
// client during registration
|
||||
|
@ -123,89 +133,6 @@ type (
|
|||
Nodes []*Node
|
||||
)
|
||||
|
||||
type NodeAddresses []netip.Addr
|
||||
|
||||
func (na NodeAddresses) Sort() {
|
||||
sort.Slice(na, func(index1, index2 int) bool {
|
||||
if na[index1].Is4() && na[index2].Is6() {
|
||||
return true
|
||||
}
|
||||
if na[index1].Is6() && na[index2].Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
return na[index1].Compare(na[index2]) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func (na NodeAddresses) StringSlice() []string {
|
||||
na.Sort()
|
||||
strSlice := make([]string, 0, len(na))
|
||||
for _, addr := range na {
|
||||
strSlice = append(strSlice, addr.String())
|
||||
}
|
||||
|
||||
return strSlice
|
||||
}
|
||||
|
||||
func (na NodeAddresses) Prefixes() []netip.Prefix {
|
||||
addrs := []netip.Prefix{}
|
||||
for _, nodeAddress := range na {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (na NodeAddresses) InIPSet(set *netipx.IPSet) bool {
|
||||
for _, nodeAddr := range na {
|
||||
if set.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AppendToIPSet adds the individual ips in NodeAddresses to a
|
||||
// given netipx.IPSetBuilder.
|
||||
func (na NodeAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
for _, ip := range na {
|
||||
build.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (na *NodeAddresses) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case string:
|
||||
addresses := strings.Split(value, ",")
|
||||
*na = (*na)[:0]
|
||||
for _, addr := range addresses {
|
||||
if len(addr) < 1 {
|
||||
continue
|
||||
}
|
||||
parsed, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*na = append(*na, parsed)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (na NodeAddresses) Value() (driver.Value, error) {
|
||||
addresses := strings.Join(na.StringSlice(), ",")
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// IsExpired returns whether the node registration has expired.
|
||||
func (node Node) IsExpired() bool {
|
||||
// If Expiry is not set, the client has not indicated that
|
||||
|
@ -224,8 +151,65 @@ func (node *Node) IsEphemeral() bool {
|
|||
return node.AuthKey != nil && node.AuthKey.Ephemeral
|
||||
}
|
||||
|
||||
func (node *Node) IPs() []netip.Addr {
|
||||
var ret []netip.Addr
|
||||
|
||||
if node.IPv4 != nil {
|
||||
ret = append(ret, *node.IPv4)
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
ret = append(ret, *node.IPv6)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (node *Node) Prefixes() []netip.Prefix {
|
||||
addrs := []netip.Prefix{}
|
||||
for _, nodeAddress := range node.IPs() {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (node *Node) IPsAsString() []string {
|
||||
var ret []string
|
||||
|
||||
if node.IPv4 != nil {
|
||||
ret = append(ret, node.IPv4.String())
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
ret = append(ret, node.IPv6.String())
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (node *Node) InIPSet(set *netipx.IPSet) bool {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if set.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AppendToIPSet adds the individual ips in NodeAddresses to a
|
||||
// given netipx.IPSetBuilder.
|
||||
func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
for _, ip := range node.IPs() {
|
||||
build.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
||||
allowedIPs := append([]netip.Addr{}, node2.IPAddresses...)
|
||||
src := node.IPs()
|
||||
allowedIPs := node2.IPs()
|
||||
|
||||
for _, route := range node2.Routes {
|
||||
if route.Enabled {
|
||||
|
@ -237,7 +221,7 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
|||
// TODO(kradalby): Cache or pregen this
|
||||
matcher := matcher.MatchFromFilterRule(rule)
|
||||
|
||||
if !matcher.SrcsContainsIPs([]netip.Addr(node.IPAddresses)) {
|
||||
if !matcher.SrcsContainsIPs(src) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -250,13 +234,16 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
|||
}
|
||||
|
||||
func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
||||
found := make(Nodes, 0)
|
||||
var found Nodes
|
||||
|
||||
for _, node := range nodes {
|
||||
for _, mIP := range node.IPAddresses {
|
||||
if ip == mIP {
|
||||
found = append(found, node)
|
||||
}
|
||||
if node.IPv4 != nil && ip == *node.IPv4 {
|
||||
found = append(found, node)
|
||||
continue
|
||||
}
|
||||
|
||||
if node.IPv6 != nil && ip == *node.IPv6 {
|
||||
found = append(found, node)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,10 +268,22 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
|
|||
|
||||
hi, err := json.Marshal(node.Hostinfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
|
||||
return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
|
||||
}
|
||||
node.HostinfoDatabaseField = string(hi)
|
||||
|
||||
if node.IPv4 != nil {
|
||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true
|
||||
} else {
|
||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true
|
||||
} else {
|
||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -296,19 +295,19 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
|
|||
func (node *Node) AfterFind(tx *gorm.DB) error {
|
||||
var machineKey key.MachinePublic
|
||||
if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling machine key from db: %w", err)
|
||||
}
|
||||
node.MachineKey = machineKey
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling node key from db: %w", err)
|
||||
}
|
||||
node.NodeKey = nodeKey
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling disco key from db: %w", err)
|
||||
}
|
||||
node.DiscoKey = discoKey
|
||||
|
||||
|
@ -316,7 +315,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
|||
for idx, ep := range node.EndpointsDatabaseField {
|
||||
addrPort, err := netip.ParseAddrPort(ep)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint from db: %w", err)
|
||||
return fmt.Errorf("parsing endpoint from db: %w", err)
|
||||
}
|
||||
|
||||
endpoints[idx] = addrPort
|
||||
|
@ -325,12 +324,28 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
|||
|
||||
var hi tailcfg.Hostinfo
|
||||
if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil {
|
||||
log.Trace().Err(err).Msgf("Hostinfo content: %s", node.HostinfoDatabaseField)
|
||||
|
||||
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
|
||||
}
|
||||
node.Hostinfo = &hi
|
||||
|
||||
if node.IPv4DatabaseField.Valid {
|
||||
ip, err := netip.ParseAddr(node.IPv4DatabaseField.String)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing IPv4 from database: %w", err)
|
||||
}
|
||||
|
||||
node.IPv4 = &ip
|
||||
}
|
||||
|
||||
if node.IPv6DatabaseField.Valid {
|
||||
ip, err := netip.ParseAddr(node.IPv6DatabaseField.String)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing IPv6 from database: %w", err)
|
||||
}
|
||||
|
||||
node.IPv6 = &ip
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -339,9 +354,11 @@ func (node *Node) Proto() *v1.Node {
|
|||
Id: uint64(node.ID),
|
||||
MachineKey: node.MachineKey.String(),
|
||||
|
||||
NodeKey: node.NodeKey.String(),
|
||||
DiscoKey: node.DiscoKey.String(),
|
||||
IpAddresses: node.IPAddresses.StringSlice(),
|
||||
NodeKey: node.NodeKey.String(),
|
||||
DiscoKey: node.DiscoKey.String(),
|
||||
|
||||
// TODO(kradalby): replace list with v4, v6 field?
|
||||
IpAddresses: node.IPsAsString(),
|
||||
Name: node.Hostname,
|
||||
GivenName: node.GivenName,
|
||||
User: node.User.Proto(),
|
||||
|
|
|
@ -12,6 +12,10 @@ import (
|
|||
)
|
||||
|
||||
func Test_NodeCanAccess(t *testing.T) {
|
||||
iap := func(ipStr string) *netip.Addr {
|
||||
ip := netip.MustParseAddr(ipStr)
|
||||
return &ip
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
node1 Node
|
||||
|
@ -22,10 +26,10 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
{
|
||||
name: "no-rules",
|
||||
node1: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")},
|
||||
IPv4: iap("10.0.0.1"),
|
||||
},
|
||||
node2: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
|
||||
IPv4: iap("10.0.0.2"),
|
||||
},
|
||||
rules: []tailcfg.FilterRule{},
|
||||
want: false,
|
||||
|
@ -33,10 +37,10 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
{
|
||||
name: "wildcard",
|
||||
node1: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")},
|
||||
IPv4: iap("10.0.0.1"),
|
||||
},
|
||||
node2: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
|
||||
IPv4: iap("10.0.0.2"),
|
||||
},
|
||||
rules: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -54,10 +58,10 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
{
|
||||
name: "other-cant-access-src",
|
||||
node1: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
IPv4: iap("100.64.0.1"),
|
||||
},
|
||||
node2: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
IPv4: iap("100.64.0.3"),
|
||||
},
|
||||
rules: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -72,10 +76,10 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
{
|
||||
name: "dest-cant-access-src",
|
||||
node1: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
IPv4: iap("100.64.0.3"),
|
||||
},
|
||||
node2: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
IPv4: iap("100.64.0.2"),
|
||||
},
|
||||
rules: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -90,10 +94,10 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
{
|
||||
name: "src-can-access-dest",
|
||||
node1: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
IPv4: iap("100.64.0.2"),
|
||||
},
|
||||
node2: Node{
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
IPv4: iap("100.64.0.3"),
|
||||
},
|
||||
rules: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -118,32 +122,6 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNodeAddressesOrder(t *testing.T) {
|
||||
machineAddresses := NodeAddresses{
|
||||
netip.MustParseAddr("2001:db8::2"),
|
||||
netip.MustParseAddr("100.64.0.2"),
|
||||
netip.MustParseAddr("2001:db8::1"),
|
||||
netip.MustParseAddr("100.64.0.1"),
|
||||
}
|
||||
|
||||
strSlice := machineAddresses.StringSlice()
|
||||
expected := []string{
|
||||
"100.64.0.1",
|
||||
"100.64.0.2",
|
||||
"2001:db8::1",
|
||||
"2001:db8::2",
|
||||
}
|
||||
|
||||
if len(strSlice) != len(expected) {
|
||||
t.Fatalf("unexpected slice length: got %v, want %v", len(strSlice), len(expected))
|
||||
}
|
||||
for i, addr := range strSlice {
|
||||
if addr != expected[i] {
|
||||
t.Errorf("unexpected address at index %v: got %v, want %v", i, addr, expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeFQDN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -103,33 +103,7 @@ func CheckForFQDNRules(name string) error {
|
|||
|
||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||
func GenerateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
|
||||
fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
|
||||
for _, ipPrefix := range ipPrefixes {
|
||||
var generateDNSRoot func(netip.Prefix) []dnsname.FQDN
|
||||
switch ipPrefix.Addr().BitLen() {
|
||||
case ipv4AddressLength:
|
||||
generateDNSRoot = generateIPv4DNSRootDomain
|
||||
|
||||
case ipv6AddressLength:
|
||||
generateDNSRoot = generateIPv6DNSRootDomain
|
||||
|
||||
default:
|
||||
panic(
|
||||
fmt.Sprintf(
|
||||
"unsupported IP version with address length %d",
|
||||
ipPrefix.Addr().BitLen(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fqdns = append(fqdns, generateDNSRoot(ipPrefix)...)
|
||||
}
|
||||
|
||||
return fqdns
|
||||
}
|
||||
|
||||
func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
// Conversion to the std lib net.IPnet, a bit easier to operate
|
||||
netRange := netipx.PrefixIPNet(ipPrefix)
|
||||
maskBits, _ := netRange.Mask.Size()
|
||||
|
@ -165,7 +139,27 @@ func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
|||
return fqdns
|
||||
}
|
||||
|
||||
func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
||||
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
||||
// server (listening in 100.100.100.100 udp/53) should be used for.
|
||||
//
|
||||
// Tailscale.com includes in the list:
|
||||
// - the `BaseDomain` of the user
|
||||
// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6)
|
||||
// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`.
|
||||
// In the public SaaS this is [64-127].100.in-addr.arpa.
|
||||
//
|
||||
// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this
|
||||
// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the
|
||||
// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet.
|
||||
//
|
||||
// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this,
|
||||
// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
|
||||
// class block only.
|
||||
|
||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||
func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
const nibbleLen = 4
|
||||
|
||||
maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size()
|
||||
|
|
|
@ -148,10 +148,7 @@ func TestCheckForFQDNRules(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMagicDNSRootDomains100(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
domains := GenerateMagicDNSRootDomains(prefixes)
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
|
@ -185,10 +182,7 @@ func TestMagicDNSRootDomains100(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMagicDNSRootDomains172(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
}
|
||||
domains := GenerateMagicDNSRootDomains(prefixes)
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
|
@ -213,20 +207,14 @@ func TestMagicDNSRootDomains172(t *testing.T) {
|
|||
|
||||
// Happens when netmask is a multiple of 4 bits (sounds likely).
|
||||
func TestMagicDNSRootDomainsIPv6Single(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::/48"),
|
||||
}
|
||||
domains := GenerateMagicDNSRootDomains(prefixes)
|
||||
domains := GenerateIPv6DNSRootDomain(netip.MustParsePrefix("fd7a:115c:a1e0::/48"))
|
||||
|
||||
assert.Len(t, domains, 1)
|
||||
assert.Equal(t, "0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.", domains[0].WithTrailingDot())
|
||||
}
|
||||
|
||||
func TestMagicDNSRootDomainsIPv6SingleMultiple(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::/50"),
|
||||
}
|
||||
domains := GenerateMagicDNSRootDomains(prefixes)
|
||||
domains := GenerateIPv6DNSRootDomain(netip.MustParsePrefix("fd7a:115c:a1e0::/50"))
|
||||
|
||||
yieldsRoot := func(dom string) bool {
|
||||
for _, candidate := range domains {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue