new IP allocator and add postgres to integration tests. (#1756)

This commit is contained in:
Kristoffer Dalby 2024-02-18 19:31:29 +01:00 committed by GitHub
parent f581d4d9c0
commit 384ca03208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
119 changed files with 3686 additions and 443 deletions

View file

@ -1,106 +0,0 @@
// Codehere is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package db
import (
"errors"
"fmt"
"net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
return getAvailableIPs(rx, hsdb.ipPrefixes)
})
}
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
var ips types.NodeAddresses
var err error
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
ip, err = getAvailableIP(rx, ipPrefix)
if err != nil {
return ips, err
}
ips = append(ips, *ip)
}
return ips, err
}
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := getUsedIPs(rx)
if err != nil {
return nil, err
}
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
// Get the first IP in our prefix
ip := ipPrefixNetworkAddress.Next()
for {
if !ipPrefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
switch {
case ip.Compare(ipPrefixBroadcastAddress) == 0:
fallthrough
case usedIps.Contains(ip):
fallthrough
case ip == netip.Addr{} || ip.IsLoopback():
ip = ip.Next()
continue
default:
return &ip, nil
}
}
}
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
ipSet, err := ips.IPSet()
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
}
return ipSet, nil
}

View file

@ -1,196 +0,0 @@
package db
import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
tx := db.DB.Begin()
defer tx.Rollback()
ips, err := getAvailableIPs(tx, []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.Write(func(tx *gorm.DB) error {
return tx.Save(&node).Error
})
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expectedIPSetBuilder := netipx.IPSetBuilder{}
expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
c.Assert(usedIps.Contains(expected), check.Equals, true)
node1, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)
c.Assert(len(node1.IPAddresses), check.Equals, 1)
c.Assert(node1.IPAddresses[0], check.Equals, expected)
}
func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
ipPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
}
for index := 1; index <= 350; index++ {
tx := db.DB.Begin()
ips, err := getAvailableIPs(tx, ipPrefixes)
c.Assert(err, check.IsNil)
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = getNode(tx, "test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: uint64(index),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
tx.Save(&node)
c.Assert(tx.Commit().Error, check.IsNil)
}
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
expected9 := netip.MustParseAddr("10.27.0.10")
expected300 := netip.MustParseAddr("10.27.0.45")
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
notExpectedIPSetBuilder.Add(expected0)
notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300)
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
c.Assert(err, check.IsNil)
// We actually expect it to be a lot larger
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
c.Assert(usedIps.Contains(expected0), check.Equals, true)
c.Assert(usedIps.Contains(expected9), check.Equals, true)
c.Assert(usedIps.Contains(expected300), check.Equals, true)
// Check that we can read back the IPs
node1, err := db.GetNodeByID(1)
c.Assert(err, check.IsNil)
c.Assert(len(node1.IPAddresses), check.Equals, 1)
c.Assert(
node1.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.1"),
)
node50, err := db.GetNodeByID(50)
c.Assert(err, check.IsNil)
c.Assert(len(node50.IPAddresses), check.Equals, 1)
c.Assert(
node50.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.50"),
)
expectedNextIP := netip.MustParseAddr("10.27.1.95")
nextIP, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP), check.Equals, 1)
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.DB.Save(&node)
ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String())
}

View file

@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"net/netip"
"path/filepath"
"strconv"
"strings"
@ -18,7 +17,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -35,7 +33,6 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
ipPrefixes []netip.Prefix
baseDomain string
}
@ -43,8 +40,6 @@ type HSDatabase struct {
// rather than arguments.
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
@ -327,7 +322,6 @@ func NewHeadscaleDatabase(
db := HSDatabase{
DB: dbConn,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
}
@ -351,6 +345,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
}
log.Info().
Str("database", types.DatabaseSqlite).
Str("path", cfg.Sqlite.Path).
Msg("Opening database")
db, err := gorm.Open(
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
@ -379,6 +378,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
cfg.Postgres.User,
)
log.Info().
Str("database", types.DatabasePostgres).
Str("path", dbString).
Msg("Opening database")
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"

148
hscontrol/db/ip.go Normal file
View file

@ -0,0 +1,148 @@
package db
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
// IPAllocator is a singleton responsible for allocating
// IP addresses for nodes and making sure the same
// address is not handed out twice. There can only be one
// and it needs to be created before any other database
// writes occur.
type IPAllocator struct {
mu sync.Mutex
prefix4 netip.Prefix
prefix6 netip.Prefix
// Previous IPs handed out
prev4 netip.Addr
prev6 netip.Addr
// Set of all IPs handed out.
// This might not be in sync with the database,
// but it is more conservative. If saves to the
// database fails, the IP will be allocated here
// until the next restart of Headscale.
usedIPs netipx.IPSetBuilder
}
// NewIPAllocator returns a new IPAllocator singleton which
// can be used to hand out unique IP addresses within the
// provided IPv4 and IPv6 prefix. It needs to be created
// when headscale starts and needs to finish its read
// transaction before any writes to the database occur.
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
var addressesSlices []string
if db != nil {
db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
})
}
var ips netipx.IPSetBuilder
// Add network and broadcast addrs to used pool so they
// are not handed out to nodes.
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
ips.Add(network4)
ips.Add(broadcast4)
ips.Add(network6)
ips.Add(broadcast6)
// Fetch all the IP Addresses currently handed out from the Database
// and add them to the used IP set.
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return nil, fmt.Errorf(
"parsing IPs from database %v: %w", machineAddresses,
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
// Build the initial IPSet to validate that we can use it.
_, err := ips.IPSet()
if err != nil {
return nil, fmt.Errorf(
"building initial IP Set: %w",
err,
)
}
return &IPAllocator{
usedIPs: ips,
prefix4: prefix4,
prefix6: prefix6,
// Use network as starting point, it will be used to call .Next()
// TODO(kradalby): Could potentially take all the IPs loaded from
// the database into account to start at a more "educated" location.
prev4: network4,
prev6: network6,
}, nil
}
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
i.mu.Lock()
defer i.mu.Unlock()
v4, err := i.next(i.prev4, i.prefix4)
if err != nil {
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
}
v6, err := i.next(i.prev6, i.prefix6)
if err != nil {
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
}
return types.NodeAddresses{*v4, *v6}, nil
}
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
// Get the first IP in our prefix
ip := prev.Next()
// TODO(kradalby): maybe this can be done less often.
set, err := i.usedIPs.IPSet()
if err != nil {
return nil, err
}
for {
if !prefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
// Check if the IP has already been allocated.
if set.Contains(ip) {
ip = ip.Next()
continue
}
i.usedIPs.Add(ip)
return &ip, nil
}
}

151
hscontrol/db/ip_test.go Normal file
View file

@ -0,0 +1,151 @@
package db
import (
"net/netip"
"os"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
func TestIPAllocator(t *testing.T) {
mpp := func(pref string) netip.Prefix {
return netip.MustParsePrefix(pref)
}
na := func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
newDb := func() *HSDatabase {
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
t.Fatalf("creating temp dir: %s", err)
}
db, _ = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
"",
)
return db
}
tests := []struct {
name string
dbFunc func() *HSDatabase
prefix4 netip.Prefix
prefix6 netip.Prefix
getCount int
want []types.NodeAddresses
}{
{
name: "simple",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
},
},
{
name: "simple-with-db",
dbFunc: func() *HSDatabase {
db := newDb()
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
},
},
{
name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase {
db := newDb()
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 2,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
{
na("100.64.0.3"),
na("fd7a:115c:a1e0::3"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc()
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
spew.Dump(alloc)
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
var got []types.NodeAddresses
for range tt.getCount {
gotSet, err := alloc.Next()
if err != nil {
t.Fatalf("allocating next IP: %s", err)
}
got = append(got, gotSet)
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -307,7 +307,7 @@ func RegisterNodeFromAuthCallback(
userName string,
nodeExpiry *time.Time,
registrationMethod string,
ipPrefixes []netip.Prefix,
addrs types.NodeAddresses,
) (*types.Node, error) {
log.Debug().
Str("machine_key", mkey.ShortString()).
@ -343,7 +343,7 @@ func RegisterNodeFromAuthCallback(
node, err := RegisterNode(
tx,
registrationNode,
ipPrefixes,
addrs,
)
if err == nil {
@ -359,14 +359,14 @@ func RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache
}
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, hsdb.ipPrefixes)
return RegisterNode(tx, node, addrs)
})
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
@ -393,18 +393,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
return &node, nil
}
ips, err := getAvailableIPs(tx, ipPrefixes)
if err != nil {
log.Error().
Caller().
Err(err).
Str("node", node.Hostname).
Msg("Could not find IP for the new node")
return nil, err
}
node.IPAddresses = ips
node.IPAddresses = addrs
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
@ -413,7 +402,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
log.Trace().
Caller().
Str("node", node.Hostname).
Str("ip", strings.Join(ips.StringSlice(), ",")).
Str("ip", strings.Join(addrs.StringSlice(), ",")).
Msg("Node registered with the database")
return &node, nil

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
@ -661,10 +660,6 @@ func TestFailoverRoute(t *testing.T) {
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
assert.NoError(t, err)

View file

@ -2,11 +2,9 @@ package db
import (
"log"
"net/netip"
"os"
"testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
)
@ -52,10 +50,6 @@ func (s *Suite) ResetDB(c *check.C) {
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
if err != nil {