Migrate IP fields in database to dedicated columns (#1869)

This commit is contained in:
Kristoffer Dalby 2024-04-17 07:03:06 +02:00 committed by GitHub
parent 85cef84e17
commit 2ce23df45a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 1885 additions and 1055 deletions

View file

@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"net/netip"
"path/filepath"
"strconv"
"strings"
@ -330,6 +331,66 @@ func NewHeadscaleDatabase(
return nil
},
},
{
// Replace column with IP address list with dedicated
// IP v4 and v6 column.
// Note that previously, the list _could_ contain more
// than two addresses, which should not really happen.
// In that case, the first occurence of each type will
// be kept.
ID: "2024041121742",
Migrate: func(tx *gorm.DB) error {
_ = tx.Migrator().AddColumn(&types.Node{}, "ipv4")
_ = tx.Migrator().AddColumn(&types.Node{}, "ipv6")
type node struct {
ID uint64 `gorm:"column:id"`
Addresses string `gorm:"column:ip_addresses"`
}
var nodes []node
_ = tx.Raw("SELECT id, ip_addresses FROM nodes").Scan(&nodes).Error
for _, node := range nodes {
addrs := strings.Split(node.Addresses, ",")
if len(addrs) == 0 {
fmt.Errorf("no addresses found for node(%d)", node.ID)
}
var v4 *netip.Addr
var v6 *netip.Addr
for _, addrStr := range addrs {
addr, err := netip.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parsing IP for node(%d) from database: %w", node.ID, err)
}
if addr.Is4() && v4 == nil {
v4 = &addr
}
if addr.Is6() && v6 == nil {
v6 = &addr
}
}
err = tx.Save(&types.Node{ID: types.NodeID(node.ID), IPv4: v4, IPv6: v6}).Error
if err != nil {
return fmt.Errorf("saving ip addresses to new columns: %w", err)
}
}
_ = tx.Migrator().DropColumn(&types.Node{}, "ip_addresses")
return nil
},
Rollback: func(tx *gorm.DB) error {
return nil
},
},
},
)

View file

@ -1,13 +1,17 @@
package db
import (
"crypto/rand"
"database/sql"
"errors"
"fmt"
"math/big"
"net/netip"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"gorm.io/gorm"
)
@ -20,13 +24,16 @@ import (
type IPAllocator struct {
mu sync.Mutex
prefix4 netip.Prefix
prefix6 netip.Prefix
prefix4 *netip.Prefix
prefix6 *netip.Prefix
// Previous IPs handed out
prev4 netip.Addr
prev6 netip.Addr
// strategy used for handing out IP addresses.
strategy types.IPAllocationStrategy
// Set of all IPs handed out.
// This might not be in sync with the database,
// but it is more conservative. If saves to the
@ -40,40 +47,71 @@ type IPAllocator struct {
// provided IPv4 and IPv6 prefix. It needs to be created
// when headscale starts and needs to finish its read
// transaction before any writes to the database occur.
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
var addressesSlices []string
func NewIPAllocator(
db *HSDatabase,
prefix4, prefix6 *netip.Prefix,
strategy types.IPAllocationStrategy,
) (*IPAllocator, error) {
ret := IPAllocator{
prefix4: prefix4,
prefix6: prefix6,
strategy: strategy,
}
var v4s []sql.NullString
var v6s []sql.NullString
if db != nil {
db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
err := db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err)
}
err = db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err)
}
}
var ips netipx.IPSetBuilder
// Add network and broadcast addrs to used pool so they
// are not handed out to nodes.
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
ips.Add(network4)
ips.Add(broadcast4)
ips.Add(network6)
ips.Add(broadcast6)
if prefix4 != nil {
network4, broadcast4 := util.GetIPPrefixEndpoints(*prefix4)
ips.Add(network4)
ips.Add(broadcast4)
// Use network as starting point, it will be used to call .Next()
// TODO(kradalby): Could potentially take all the IPs loaded from
// the database into account to start at a more "educated" location.
ret.prev4 = network4
}
if prefix6 != nil {
network6, broadcast6 := util.GetIPPrefixEndpoints(*prefix6)
ips.Add(network6)
ips.Add(broadcast6)
ret.prev6 = network6
}
// Fetch all the IP Addresses currently handed out from the Database
// and add them to the used IP set.
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return nil, fmt.Errorf(
"parsing IPs from database %v: %w", machineAddresses,
err,
)
}
for _, addrStr := range append(v4s, v6s...) {
if addrStr.Valid {
addr, err := netip.ParseAddr(addrStr.String)
if err != nil {
return nil, fmt.Errorf("parsing IP address from database: %w", err)
}
for _, ip := range machineAddresses {
ips.Add(ip)
ips.Add(addr)
}
}
@ -86,42 +124,61 @@ func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator
)
}
return &IPAllocator{
usedIPs: ips,
ret.usedIPs = ips
prefix4: prefix4,
prefix6: prefix6,
// Use network as starting point, it will be used to call .Next()
// TODO(kradalby): Could potentially take all the IPs loaded from
// the database into account to start at a more "educated" location.
prev4: network4,
prev6: network6,
}, nil
return &ret, nil
}
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()
v4, err := i.next(i.prev4, i.prefix4)
if err != nil {
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
var err error
var ret4 *netip.Addr
var ret6 *netip.Addr
if i.prefix4 != nil {
ret4, err = i.next(i.prev4, i.prefix4)
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
}
i.prev4 = *ret4
}
v6, err := i.next(i.prev6, i.prefix6)
if err != nil {
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
if i.prefix6 != nil {
ret6, err = i.next(i.prev6, i.prefix6)
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
}
i.prev6 = *ret6
}
return types.NodeAddresses{*v4, *v6}, nil
return ret4, ret6, nil
}
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
// Get the first IP in our prefix
ip := prev.Next()
func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()
return i.next(prev, prefix)
}
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
var err error
var ip netip.Addr
switch i.strategy {
case types.IPAllocationStrategySequential:
// Get the first IP in our prefix
ip = prev.Next()
case types.IPAllocationStrategyRandom:
ip, err = randomNext(*prefix)
if err != nil {
return nil, fmt.Errorf("getting random IP: %w", err)
}
}
// TODO(kradalby): maybe this can be done less often.
set, err := i.usedIPs.IPSet()
@ -136,7 +193,15 @@ func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, e
// Check if the IP has already been allocated.
if set.Contains(ip) {
ip = ip.Next()
switch i.strategy {
case types.IPAllocationStrategySequential:
ip = ip.Next()
case types.IPAllocationStrategyRandom:
ip, err = randomNext(*prefix)
if err != nil {
return nil, fmt.Errorf("getting random IP: %w", err)
}
}
continue
}
@ -146,3 +211,120 @@ func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, e
return &ip, nil
}
}
func randomNext(pfx netip.Prefix) (netip.Addr, error) {
rang := netipx.RangeOfPrefix(pfx)
fromIP, toIP := rang.From(), rang.To()
var from, to big.Int
from.SetBytes(fromIP.AsSlice())
to.SetBytes(toIP.AsSlice())
// Find the max, this is how we can do "random range",
// get the "max" as 0 -> to - from and then add back from
// after.
tempMax := big.NewInt(0).Sub(&to, &from)
out, err := rand.Int(rand.Reader, tempMax)
if err != nil {
return netip.Addr{}, fmt.Errorf("generating random IP: %w", err)
}
valInRange := big.NewInt(0).Add(&from, out)
ip, ok := netip.AddrFromSlice(valInRange.Bytes())
if !ok {
return netip.Addr{}, fmt.Errorf("generated ip bytes are invalid ip")
}
if !pfx.Contains(ip) {
return netip.Addr{}, fmt.Errorf(
"generated ip(%s) not in prefix(%s)",
ip.String(),
pfx.String(),
)
}
return ip, nil
}
// BackfillNodeIPs will take a database transaction, and
// iterate through all of the current nodes in headscale
// and ensure it has IP addresses according to the current
// configuration.
// This means that if both IPv4 and IPv6 is set in the
// config, and some nodes are missing that type of IP,
// it will be added.
// If a prefix type has been removed (IPv4 or IPv6), it
// will remove the IPs in that family from the node.
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
var err error
var ret []string
err = db.Write(func(tx *gorm.DB) error {
if i == nil {
return errors.New("backfilling IPs: ip allocator was nil")
}
log.Trace().Msgf("starting to backfill IPs")
nodes, err := ListNodes(tx)
if err != nil {
return fmt.Errorf("listing nodes to backfill IPs: %w", err)
}
for _, node := range nodes {
log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
changed := false
// IPv4 prefix is set, but node ip is missing, alloc
if i.prefix4 != nil && node.IPv4 == nil {
ret4, err := i.nextLocked(i.prev4, i.prefix4)
if err != nil {
return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
}
node.IPv4 = ret4
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
}
// IPv6 prefix is set, but node ip is missing, alloc
if i.prefix6 != nil && node.IPv6 == nil {
ret6, err := i.nextLocked(i.prev6, i.prefix6)
if err != nil {
return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
}
node.IPv6 = ret6
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
}
// IPv4 prefix is not set, but node has IP, remove
if i.prefix4 == nil && node.IPv4 != nil {
ret = append(ret, fmt.Sprintf("removing IPv4 %q from Node(%d) %q", node.IPv4.String(), node.ID, node.Hostname))
node.IPv4 = nil
changed = true
}
// IPv6 prefix is not set, but node has IP, remove
if i.prefix6 == nil && node.IPv6 != nil {
ret = append(ret, fmt.Sprintf("removing IPv6 %q from Node(%d) %q", node.IPv6.String(), node.ID, node.Hostname))
node.IPv6 = nil
changed = true
}
if changed {
err := tx.Save(node).Error
if err != nil {
return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err)
}
}
}
return nil
})
return ret, err
}

View file

@ -1,49 +1,41 @@
package db
import (
"database/sql"
"fmt"
"net/netip"
"os"
"strings"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
func TestIPAllocator(t *testing.T) {
mpp := func(pref string) netip.Prefix {
return netip.MustParsePrefix(pref)
}
na := func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
newDb := func() *HSDatabase {
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
t.Fatalf("creating temp dir: %s", err)
}
db, _ = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
"",
)
return db
}
var mpp = func(pref string) *netip.Prefix {
p := netip.MustParsePrefix(pref)
return &p
}
var na = func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
var nap = func(pref string) *netip.Addr {
n := na(pref)
return &n
}
func TestIPAllocatorSequential(t *testing.T) {
tests := []struct {
name string
dbFunc func() *HSDatabase
prefix4 netip.Prefix
prefix6 netip.Prefix
prefix4 *netip.Prefix
prefix6 *netip.Prefix
getCount int
want []types.NodeAddresses
want4 []netip.Addr
want6 []netip.Addr
}{
{
name: "simple",
@ -56,23 +48,49 @@ func TestIPAllocator(t *testing.T) {
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
want4: []netip.Addr{
na("100.64.0.1"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
},
},
{
name: "simple-v4",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
getCount: 1,
want4: []netip.Addr{
na("100.64.0.1"),
},
},
{
name: "simple-v6",
dbFunc: func() *HSDatabase {
return nil
},
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
},
},
{
name: "simple-with-db",
dbFunc: func() *HSDatabase {
db := newDb()
db := dbForTest(t, "simple-with-db")
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
return db
@ -83,23 +101,21 @@ func TestIPAllocator(t *testing.T) {
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
want4: []netip.Addr{
na("100.64.0.2"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::2"),
},
},
{
name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase {
db := newDb()
db := dbForTest(t, "before-after-free-middle-in-db")
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
IPv4: nap("100.64.0.2"),
IPv6: nap("fd7a:115c:a1e0::2"),
})
return db
@ -110,15 +126,13 @@ func TestIPAllocator(t *testing.T) {
getCount: 2,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
{
na("100.64.0.3"),
na("fd7a:115c:a1e0::3"),
},
want4: []netip.Addr{
na("100.64.0.1"),
na("100.64.0.3"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
na("fd7a:115c:a1e0::3"),
},
},
}
@ -127,24 +141,347 @@ func TestIPAllocator(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc()
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
alloc, _ := NewIPAllocator(
db,
tt.prefix4,
tt.prefix6,
types.IPAllocationStrategySequential,
)
spew.Dump(alloc)
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
var got []types.NodeAddresses
var got4s []netip.Addr
var got6s []netip.Addr
for range tt.getCount {
gotSet, err := alloc.Next()
got4, got6, err := alloc.Next()
if err != nil {
t.Fatalf("allocating next IP: %s", err)
}
got = append(got, gotSet)
if got4 != nil {
got4s = append(got4s, *got4)
}
if got6 != nil {
got6s = append(got6s, *got6)
}
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.want6, got6s, util.Comparers...); diff != "" {
t.Errorf("IPAllocator 6s unexpected result (-want +got):\n%s", diff)
}
})
}
}
func TestIPAllocatorRandom(t *testing.T) {
tests := []struct {
name string
dbFunc func() *HSDatabase
getCount int
prefix4 *netip.Prefix
prefix6 *netip.Prefix
want4 bool
want6 bool
}{
{
name: "simple",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want4: true,
want6: true,
},
{
name: "simple-v4",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
getCount: 1,
want4: true,
want6: false,
},
{
name: "simple-v6",
dbFunc: func() *HSDatabase {
return nil
},
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want4: false,
want6: true,
},
{
name: "generate-lots-of-random",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1000,
want4: true,
want6: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc()
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom)
spew.Dump(alloc)
for range tt.getCount {
got4, got6, err := alloc.Next()
if err != nil {
t.Fatalf("allocating next IP: %s", err)
}
t.Logf("addrs ipv4: %v, ipv6: %v", got4, got6)
if tt.want4 {
if got4 == nil {
t.Fatalf("expected ipv4 addr, got nil")
}
}
if tt.want6 {
if got6 == nil {
t.Fatalf("expected ipv4 addr, got nil")
}
}
}
})
}
}
func TestBackfillIPAddresses(t *testing.T) {
fullNodeP := func(i int) *types.Node {
v4 := fmt.Sprintf("100.64.0.%d", i)
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
return &types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: v4,
},
IPv4: nap(v4),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: v6,
},
IPv6: nap(v6),
}
}
tests := []struct {
name string
dbFunc func() *HSDatabase
prefix4 *netip.Prefix
prefix6 *netip.Prefix
want types.Nodes
}{
{
name: "simple-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.1"),
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
},
{
name: "simple-backfill-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv4")
db.DB.Save(&types.Node{
IPv6: nap("fd7a:115c:a1e0::1"),
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
},
{
name: "simple-backfill-remove-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv6")
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
return db
},
prefix4: mpp("100.64.0.0/10"),
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
},
},
},
{
name: "simple-backfill-remove-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv4")
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
return db
},
prefix6: mpp("fd7a:115c:a1e0::/48"),
want: types.Nodes{
&types.Node{
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
},
{
name: "multi-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.1"),
})
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.2"),
})
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.3"),
})
db.DB.Save(&types.Node{
IPv4: nap("100.64.0.4"),
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
want: types.Nodes{
fullNodeP(1),
fullNodeP(2),
fullNodeP(3),
fullNodeP(4),
},
},
}
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
"ID",
"MachineKeyDatabaseField",
"NodeKeyDatabaseField",
"DiscoKeyDatabaseField",
"Endpoints",
"HostinfoDatabaseField",
"Hostinfo",
"Routes",
"CreatedAt",
"UpdatedAt",
))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc()
alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential)
if err != nil {
t.Fatalf("failed to set up ip alloc: %s", err)
}
logs, err := db.BackfillNodeIPs(alloc)
if err != nil {
t.Fatalf("failed to backfill: %s", err)
}
t.Logf("backfill log: \n%s", strings.Join(logs, "\n"))
got, err := db.ListNodes()
if err != nil {
t.Fatalf("failed to get nodes: %s", err)
}
if diff := cmp.Diff(tt.want, got, comps...); diff != "" {
t.Errorf("Backfill unexpected result (-want +got):\n%s", diff)
}
})
}

View file

@ -5,7 +5,6 @@ import (
"fmt"
"net/netip"
"sort"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
@ -294,7 +293,8 @@ func RegisterNodeFromAuthCallback(
userName string,
nodeExpiry *time.Time,
registrationMethod string,
addrs types.NodeAddresses,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
log.Debug().
Str("machine_key", mkey.ShortString()).
@ -330,7 +330,7 @@ func RegisterNodeFromAuthCallback(
node, err := RegisterNode(
tx,
registrationNode,
addrs,
ipv4, ipv6,
)
if err == nil {
@ -346,14 +346,14 @@ func RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, addrs)
return RegisterNode(tx, node, ipv4, ipv6)
})
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
@ -361,10 +361,10 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
Str("user", node.User.Name).
Msg("Registering node")
// If the node exists and we had already IPs for it, we just save it
// If the node exists and it already has IP(s), we just save it
// so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache
if len(node.IPAddresses) > 0 {
if node.IPv4 != nil || node.IPv6 != nil {
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
}
@ -380,7 +380,8 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
return &node, nil
}
node.IPAddresses = addrs
node.IPv4 = ipv4
node.IPv6 = ipv6
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
@ -389,7 +390,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*typ
log.Trace().
Caller().
Str("node", node.Hostname).
Str("ip", strings.Join(addrs.StringSlice(), ",")).
Msg("Node registered with the database")
return &node, nil

View file

@ -188,13 +188,12 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPAddresses: types.NodeAddresses{
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
},
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPv4: &v4,
Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
@ -301,27 +300,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
}
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
input := types.NodeAddresses([]netip.Addr{
netip.MustParseAddr("192.0.2.1"),
netip.MustParseAddr("2001:db8::1"),
})
serialized, err := input.Value()
c.Assert(err, check.IsNil)
if serial, ok := serialized.(string); ok {
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
}
var deserialized types.NodeAddresses
err = deserialized.Scan(serialized)
c.Assert(err, check.IsNil)
c.Assert(len(deserialized), check.Equals, len(input))
for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i])
}
}
func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)
@ -561,6 +539,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
// Check if a subprefix of an autoapproved route is approved
route2 := netip.MustParsePrefix("10.11.0.0/24")
v4 := netip.MustParseAddr("100.64.0.1")
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -573,7 +552,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPv4: &v4,
}
db.DB.Save(&node)

View file

@ -609,7 +609,7 @@ func EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
node *types.Node,
) error {
if len(node.IPAddresses) == 0 {
if node.IPv4 == nil && node.IPv6 == nil {
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
@ -652,7 +652,7 @@ func EnableAutoApprovedRoutes(
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if approvedIps.Contains(node.IPAddresses[0]) {
if approvedIps.Contains(*node.IPv4) {
approvedRoutes = append(approvedRoutes, advertisedRoute)
}
}