new IP allocator and add postgres to integration tests. (#1756)
This commit is contained in:
parent
f581d4d9c0
commit
384ca03208
119 changed files with 3686 additions and 443 deletions
|
@ -1,106 +0,0 @@
|
|||
// Codehere is mostly taken from github.com/tailscale/tailscale
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
|
||||
return getAvailableIPs(rx, hsdb.ipPrefixes)
|
||||
})
|
||||
}
|
||||
|
||||
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
|
||||
var ips types.NodeAddresses
|
||||
var err error
|
||||
for _, ipPrefix := range ipPrefixes {
|
||||
var ip *netip.Addr
|
||||
ip, err = getAvailableIP(rx, ipPrefix)
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
ips = append(ips, *ip)
|
||||
}
|
||||
|
||||
return ips, err
|
||||
}
|
||||
|
||||
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||
usedIps, err := getUsedIPs(rx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
|
||||
|
||||
// Get the first IP in our prefix
|
||||
ip := ipPrefixNetworkAddress.Next()
|
||||
|
||||
for {
|
||||
if !ipPrefix.Contains(ip) {
|
||||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
switch {
|
||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||
fallthrough
|
||||
case usedIps.Contains(ip):
|
||||
fallthrough
|
||||
case ip == netip.Addr{} || ip.IsLoopback():
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
|
||||
default:
|
||||
return &ip, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
// FIXME: This really deserves a better data model,
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.NodeAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to read ip from database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
ipSet, err := ips.IPSet()
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to build IP Set: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return ipSet, nil
|
||||
}
|
|
@ -1,196 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
tx := db.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ips, err := getAvailableIPs(tx, []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
})
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.Write(func(tx *gorm.DB) error {
|
||||
return tx.Save(&node).Error
|
||||
})
|
||||
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
expectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
expectedIPSetBuilder.Add(expected)
|
||||
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
|
||||
|
||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||
|
||||
node1, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(node1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(node1.IPAddresses[0], check.Equals, expected)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
ipPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
}
|
||||
|
||||
for index := 1; index <= 350; index++ {
|
||||
tx := db.DB.Begin()
|
||||
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = getNode(tx, "test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: uint64(index),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
tx.Save(&node)
|
||||
c.Assert(tx.Commit().Error, check.IsNil)
|
||||
}
|
||||
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||
expected9 := netip.MustParseAddr("10.27.0.10")
|
||||
expected300 := netip.MustParseAddr("10.27.0.45")
|
||||
|
||||
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
notExpectedIPSetBuilder.Add(expected0)
|
||||
notExpectedIPSetBuilder.Add(expected9)
|
||||
notExpectedIPSetBuilder.Add(expected300)
|
||||
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We actually expect it to be a lot larger
|
||||
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
|
||||
|
||||
c.Assert(usedIps.Contains(expected0), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected9), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||
|
||||
// Check that we can read back the IPs
|
||||
node1, err := db.GetNodeByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
node1.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.1"),
|
||||
)
|
||||
|
||||
node50, err := db.GetNodeByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node50.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
node50.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.50"),
|
||||
)
|
||||
|
||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||
nextIP, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP), check.Equals, 1)
|
||||
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.DB.Save(&node)
|
||||
|
||||
ips2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -18,7 +17,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
@ -35,7 +33,6 @@ type KV struct {
|
|||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
|
||||
ipPrefixes []netip.Prefix
|
||||
baseDomain string
|
||||
}
|
||||
|
||||
|
@ -43,8 +40,6 @@ type HSDatabase struct {
|
|||
// rather than arguments.
|
||||
func NewHeadscaleDatabase(
|
||||
cfg types.DatabaseConfig,
|
||||
notifier *notifier.Notifier,
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(cfg)
|
||||
|
@ -327,7 +322,6 @@ func NewHeadscaleDatabase(
|
|||
db := HSDatabase{
|
||||
DB: dbConn,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
}
|
||||
|
||||
|
@ -351,6 +345,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("database", types.DatabaseSqlite).
|
||||
Str("path", cfg.Sqlite.Path).
|
||||
Msg("Opening database")
|
||||
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
|
||||
&gorm.Config{
|
||||
|
@ -379,6 +378,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||
cfg.Postgres.User,
|
||||
)
|
||||
|
||||
log.Info().
|
||||
Str("database", types.DatabasePostgres).
|
||||
Str("path", dbString).
|
||||
Msg("Opening database")
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
|
|
148
hscontrol/db/ip.go
Normal file
148
hscontrol/db/ip.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// IPAllocator is a singleton responsible for allocating
|
||||
// IP addresses for nodes and making sure the same
|
||||
// address is not handed out twice. There can only be one
|
||||
// and it needs to be created before any other database
|
||||
// writes occur.
|
||||
type IPAllocator struct {
|
||||
mu sync.Mutex
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
|
||||
// Previous IPs handed out
|
||||
prev4 netip.Addr
|
||||
prev6 netip.Addr
|
||||
|
||||
// Set of all IPs handed out.
|
||||
// This might not be in sync with the database,
|
||||
// but it is more conservative. If saves to the
|
||||
// database fails, the IP will be allocated here
|
||||
// until the next restart of Headscale.
|
||||
usedIPs netipx.IPSetBuilder
|
||||
}
|
||||
|
||||
// NewIPAllocator returns a new IPAllocator singleton which
|
||||
// can be used to hand out unique IP addresses within the
|
||||
// provided IPv4 and IPv6 prefix. It needs to be created
|
||||
// when headscale starts and needs to finish its read
|
||||
// transaction before any writes to the database occur.
|
||||
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
|
||||
var addressesSlices []string
|
||||
|
||||
if db != nil {
|
||||
db.Read(func(rx *gorm.DB) error {
|
||||
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
|
||||
})
|
||||
}
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
// Add network and broadcast addrs to used pool so they
|
||||
// are not handed out to nodes.
|
||||
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
|
||||
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
|
||||
ips.Add(network4)
|
||||
ips.Add(broadcast4)
|
||||
ips.Add(network6)
|
||||
ips.Add(broadcast6)
|
||||
|
||||
// Fetch all the IP Addresses currently handed out from the Database
|
||||
// and add them to the used IP set.
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.NodeAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"parsing IPs from database %v: %w", machineAddresses,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the initial IPSet to validate that we can use it.
|
||||
_, err := ips.IPSet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"building initial IP Set: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return &IPAllocator{
|
||||
usedIPs: ips,
|
||||
|
||||
prefix4: prefix4,
|
||||
prefix6: prefix6,
|
||||
|
||||
// Use network as starting point, it will be used to call .Next()
|
||||
// TODO(kradalby): Could potentially take all the IPs loaded from
|
||||
// the database into account to start at a more "educated" location.
|
||||
prev4: network4,
|
||||
prev6: network6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
v4, err := i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
v6, err := i.next(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
|
||||
return types.NodeAddresses{*v4, *v6}, nil
|
||||
}
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
|
||||
// Get the first IP in our prefix
|
||||
ip := prev.Next()
|
||||
|
||||
// TODO(kradalby): maybe this can be done less often.
|
||||
set, err := i.usedIPs.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if !prefix.Contains(ip) {
|
||||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
// Check if the IP has already been allocated.
|
||||
if set.Contains(ip) {
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
i.usedIPs.Add(ip)
|
||||
|
||||
return &ip, nil
|
||||
}
|
||||
}
|
151
hscontrol/db/ip_test.go
Normal file
151
hscontrol/db/ip_test.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
func TestIPAllocator(t *testing.T) {
|
||||
mpp := func(pref string) netip.Prefix {
|
||||
return netip.MustParsePrefix(pref)
|
||||
}
|
||||
na := func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
newDb := func() *HSDatabase {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("creating temp dir: %s", err)
|
||||
}
|
||||
db, _ = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dbFunc func() *HSDatabase
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
getCount int
|
||||
want []types.NodeAddresses
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-with-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "before-after-free-middle-in-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 2,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
{
|
||||
na("100.64.0.3"),
|
||||
na("fd7a:115c:a1e0::3"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.dbFunc()
|
||||
|
||||
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
|
||||
|
||||
spew.Dump(alloc)
|
||||
|
||||
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
|
||||
|
||||
var got []types.NodeAddresses
|
||||
|
||||
for range tt.getCount {
|
||||
gotSet, err := alloc.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("allocating next IP: %s", err)
|
||||
}
|
||||
|
||||
got = append(got, gotSet)
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -307,7 +307,7 @@ func RegisterNodeFromAuthCallback(
|
|||
userName string,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipPrefixes []netip.Prefix,
|
||||
addrs types.NodeAddresses,
|
||||
) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
|
@ -343,7 +343,7 @@ func RegisterNodeFromAuthCallback(
|
|||
node, err := RegisterNode(
|
||||
tx,
|
||||
registrationNode,
|
||||
ipPrefixes,
|
||||
addrs,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
|
@ -359,14 +359,14 @@ func RegisterNodeFromAuthCallback(
|
|||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, hsdb.ipPrefixes)
|
||||
return RegisterNode(tx, node, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
|
@ -393,18 +393,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
|
|||
return &node, nil
|
||||
}
|
||||
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("node", node.Hostname).
|
||||
Msg("Could not find IP for the new node")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node.IPAddresses = ips
|
||||
node.IPAddresses = addrs
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
|
@ -413,7 +402,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
|
|||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str("ip", strings.Join(ips.StringSlice(), ",")).
|
||||
Str("ip", strings.Join(addrs.StringSlice(), ",")).
|
||||
Msg("Node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -661,10 +660,6 @@ func TestFailoverRoute(t *testing.T) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -2,11 +2,9 @@ package db
|
|||
|
||||
import (
|
||||
"log"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
@ -52,10 +50,6 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue