new IP allocator and add postgres to integration tests. (#1756)
This commit is contained in:
parent
f581d4d9c0
commit
384ca03208
119 changed files with 3686 additions and 443 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof" //nolint
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
|
@ -80,6 +81,7 @@ const (
|
|||
type Headscale struct {
|
||||
cfg *types.Config
|
||||
db *db.HSDatabase
|
||||
ipAlloc *db.IPAllocator
|
||||
noisePrivateKey *key.MachinePrivate
|
||||
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
@ -106,6 +108,7 @@ var (
|
|||
)
|
||||
|
||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
var err error
|
||||
if profilingEnabled {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
}
|
||||
|
@ -128,16 +131,17 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
nodeNotifier: notifier.NewNotifier(),
|
||||
}
|
||||
|
||||
database, err := db.NewHeadscaleDatabase(
|
||||
app.db, err = db.NewHeadscaleDatabase(
|
||||
cfg.Database,
|
||||
app.nodeNotifier,
|
||||
cfg.IPPrefixes,
|
||||
cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app.db = database
|
||||
app.ipAlloc, err = db.NewIPAllocator(app.db, *cfg.PrefixV4, *cfg.PrefixV6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = app.initOIDC()
|
||||
|
@ -151,7 +155,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
}
|
||||
|
||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains := util.GenerateMagicDNSRootDomains(app.cfg.IPPrefixes)
|
||||
// TODO(kradalby): revisit why this takes a list.
|
||||
magicDNSDomains := util.GenerateMagicDNSRootDomains([]netip.Prefix{*cfg.PrefixV4, *cfg.PrefixV6})
|
||||
// we might have routes already from Split DNS
|
||||
if app.cfg.DNSConfig.Routes == nil {
|
||||
app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
||||
|
|
|
@ -388,8 +388,21 @@ func (h *Headscale) handleAuthKey(
|
|||
ForcedTags: pak.Proto().GetAclTags(),
|
||||
}
|
||||
|
||||
addrs, err := h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RegistrationHandler").
|
||||
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("failed to allocate IP ")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
node, err = h.db.RegisterNode(
|
||||
nodeToRegister,
|
||||
addrs,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
|
|
@ -1,106 +0,0 @@
|
|||
// Codehere is mostly taken from github.com/tailscale/tailscale
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
|
||||
return getAvailableIPs(rx, hsdb.ipPrefixes)
|
||||
})
|
||||
}
|
||||
|
||||
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
|
||||
var ips types.NodeAddresses
|
||||
var err error
|
||||
for _, ipPrefix := range ipPrefixes {
|
||||
var ip *netip.Addr
|
||||
ip, err = getAvailableIP(rx, ipPrefix)
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
ips = append(ips, *ip)
|
||||
}
|
||||
|
||||
return ips, err
|
||||
}
|
||||
|
||||
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||
usedIps, err := getUsedIPs(rx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
|
||||
|
||||
// Get the first IP in our prefix
|
||||
ip := ipPrefixNetworkAddress.Next()
|
||||
|
||||
for {
|
||||
if !ipPrefix.Contains(ip) {
|
||||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
switch {
|
||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||
fallthrough
|
||||
case usedIps.Contains(ip):
|
||||
fallthrough
|
||||
case ip == netip.Addr{} || ip.IsLoopback():
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
|
||||
default:
|
||||
return &ip, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
// FIXME: This really deserves a better data model,
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.NodeAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to read ip from database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
ipSet, err := ips.IPSet()
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to build IP Set: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return ipSet, nil
|
||||
}
|
|
@ -1,196 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
tx := db.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ips, err := getAvailableIPs(tx, []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
})
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.Write(func(tx *gorm.DB) error {
|
||||
return tx.Save(&node).Error
|
||||
})
|
||||
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
expectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
expectedIPSetBuilder.Add(expected)
|
||||
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
|
||||
|
||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||
|
||||
node1, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(node1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(node1.IPAddresses[0], check.Equals, expected)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
ipPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
}
|
||||
|
||||
for index := 1; index <= 350; index++ {
|
||||
tx := db.DB.Begin()
|
||||
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = getNode(tx, "test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: uint64(index),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
tx.Save(&node)
|
||||
c.Assert(tx.Commit().Error, check.IsNil)
|
||||
}
|
||||
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||
expected9 := netip.MustParseAddr("10.27.0.10")
|
||||
expected300 := netip.MustParseAddr("10.27.0.45")
|
||||
|
||||
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
notExpectedIPSetBuilder.Add(expected0)
|
||||
notExpectedIPSetBuilder.Add(expected9)
|
||||
notExpectedIPSetBuilder.Add(expected300)
|
||||
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We actually expect it to be a lot larger
|
||||
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
|
||||
|
||||
c.Assert(usedIps.Contains(expected0), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected9), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||
|
||||
// Check that we can read back the IPs
|
||||
node1, err := db.GetNodeByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
node1.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.1"),
|
||||
)
|
||||
|
||||
node50, err := db.GetNodeByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node50.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
node50.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.50"),
|
||||
)
|
||||
|
||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||
nextIP, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP), check.Equals, 1)
|
||||
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.DB.Save(&node)
|
||||
|
||||
ips2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -18,7 +17,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
@ -35,7 +33,6 @@ type KV struct {
|
|||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
|
||||
ipPrefixes []netip.Prefix
|
||||
baseDomain string
|
||||
}
|
||||
|
||||
|
@ -43,8 +40,6 @@ type HSDatabase struct {
|
|||
// rather than arguments.
|
||||
func NewHeadscaleDatabase(
|
||||
cfg types.DatabaseConfig,
|
||||
notifier *notifier.Notifier,
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(cfg)
|
||||
|
@ -327,7 +322,6 @@ func NewHeadscaleDatabase(
|
|||
db := HSDatabase{
|
||||
DB: dbConn,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
}
|
||||
|
||||
|
@ -351,6 +345,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("database", types.DatabaseSqlite).
|
||||
Str("path", cfg.Sqlite.Path).
|
||||
Msg("Opening database")
|
||||
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
|
||||
&gorm.Config{
|
||||
|
@ -379,6 +378,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||
cfg.Postgres.User,
|
||||
)
|
||||
|
||||
log.Info().
|
||||
Str("database", types.DatabasePostgres).
|
||||
Str("path", dbString).
|
||||
Msg("Opening database")
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
|
|
148
hscontrol/db/ip.go
Normal file
148
hscontrol/db/ip.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// IPAllocator is a singleton responsible for allocating
|
||||
// IP addresses for nodes and making sure the same
|
||||
// address is not handed out twice. There can only be one
|
||||
// and it needs to be created before any other database
|
||||
// writes occur.
|
||||
type IPAllocator struct {
|
||||
mu sync.Mutex
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
|
||||
// Previous IPs handed out
|
||||
prev4 netip.Addr
|
||||
prev6 netip.Addr
|
||||
|
||||
// Set of all IPs handed out.
|
||||
// This might not be in sync with the database,
|
||||
// but it is more conservative. If saves to the
|
||||
// database fails, the IP will be allocated here
|
||||
// until the next restart of Headscale.
|
||||
usedIPs netipx.IPSetBuilder
|
||||
}
|
||||
|
||||
// NewIPAllocator returns a new IPAllocator singleton which
|
||||
// can be used to hand out unique IP addresses within the
|
||||
// provided IPv4 and IPv6 prefix. It needs to be created
|
||||
// when headscale starts and needs to finish its read
|
||||
// transaction before any writes to the database occur.
|
||||
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
|
||||
var addressesSlices []string
|
||||
|
||||
if db != nil {
|
||||
db.Read(func(rx *gorm.DB) error {
|
||||
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
|
||||
})
|
||||
}
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
// Add network and broadcast addrs to used pool so they
|
||||
// are not handed out to nodes.
|
||||
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
|
||||
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
|
||||
ips.Add(network4)
|
||||
ips.Add(broadcast4)
|
||||
ips.Add(network6)
|
||||
ips.Add(broadcast6)
|
||||
|
||||
// Fetch all the IP Addresses currently handed out from the Database
|
||||
// and add them to the used IP set.
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.NodeAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"parsing IPs from database %v: %w", machineAddresses,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the initial IPSet to validate that we can use it.
|
||||
_, err := ips.IPSet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"building initial IP Set: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return &IPAllocator{
|
||||
usedIPs: ips,
|
||||
|
||||
prefix4: prefix4,
|
||||
prefix6: prefix6,
|
||||
|
||||
// Use network as starting point, it will be used to call .Next()
|
||||
// TODO(kradalby): Could potentially take all the IPs loaded from
|
||||
// the database into account to start at a more "educated" location.
|
||||
prev4: network4,
|
||||
prev6: network6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
v4, err := i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
v6, err := i.next(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
|
||||
return types.NodeAddresses{*v4, *v6}, nil
|
||||
}
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
|
||||
// Get the first IP in our prefix
|
||||
ip := prev.Next()
|
||||
|
||||
// TODO(kradalby): maybe this can be done less often.
|
||||
set, err := i.usedIPs.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if !prefix.Contains(ip) {
|
||||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
// Check if the IP has already been allocated.
|
||||
if set.Contains(ip) {
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
i.usedIPs.Add(ip)
|
||||
|
||||
return &ip, nil
|
||||
}
|
||||
}
|
151
hscontrol/db/ip_test.go
Normal file
151
hscontrol/db/ip_test.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
func TestIPAllocator(t *testing.T) {
|
||||
mpp := func(pref string) netip.Prefix {
|
||||
return netip.MustParsePrefix(pref)
|
||||
}
|
||||
na := func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
newDb := func() *HSDatabase {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("creating temp dir: %s", err)
|
||||
}
|
||||
db, _ = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dbFunc func() *HSDatabase
|
||||
|
||||
prefix4 netip.Prefix
|
||||
prefix6 netip.Prefix
|
||||
getCount int
|
||||
want []types.NodeAddresses
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
dbFunc: func() *HSDatabase {
|
||||
return nil
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple-with-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 1,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "before-after-free-middle-in-db",
|
||||
dbFunc: func() *HSDatabase {
|
||||
db := newDb()
|
||||
|
||||
db.DB.Save(&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
na("100.64.0.2"),
|
||||
na("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
})
|
||||
|
||||
return db
|
||||
},
|
||||
|
||||
prefix4: mpp("100.64.0.0/10"),
|
||||
prefix6: mpp("fd7a:115c:a1e0::/48"),
|
||||
|
||||
getCount: 2,
|
||||
|
||||
want: []types.NodeAddresses{
|
||||
{
|
||||
na("100.64.0.1"),
|
||||
na("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
{
|
||||
na("100.64.0.3"),
|
||||
na("fd7a:115c:a1e0::3"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.dbFunc()
|
||||
|
||||
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
|
||||
|
||||
spew.Dump(alloc)
|
||||
|
||||
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
|
||||
|
||||
var got []types.NodeAddresses
|
||||
|
||||
for range tt.getCount {
|
||||
gotSet, err := alloc.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("allocating next IP: %s", err)
|
||||
}
|
||||
|
||||
got = append(got, gotSet)
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -307,7 +307,7 @@ func RegisterNodeFromAuthCallback(
|
|||
userName string,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipPrefixes []netip.Prefix,
|
||||
addrs types.NodeAddresses,
|
||||
) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
|
@ -343,7 +343,7 @@ func RegisterNodeFromAuthCallback(
|
|||
node, err := RegisterNode(
|
||||
tx,
|
||||
registrationNode,
|
||||
ipPrefixes,
|
||||
addrs,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
|
@ -359,14 +359,14 @@ func RegisterNodeFromAuthCallback(
|
|||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, hsdb.ipPrefixes)
|
||||
return RegisterNode(tx, node, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
|
@ -393,18 +393,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
|
|||
return &node, nil
|
||||
}
|
||||
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("node", node.Hostname).
|
||||
Msg("Could not find IP for the new node")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node.IPAddresses = ips
|
||||
node.IPAddresses = addrs
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
|
@ -413,7 +402,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
|
|||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str("ip", strings.Join(ips.StringSlice(), ",")).
|
||||
Str("ip", strings.Join(addrs.StringSlice(), ",")).
|
||||
Msg("Node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -661,10 +660,6 @@ func TestFailoverRoute(t *testing.T) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -2,11 +2,9 @@ package db
|
|||
|
||||
import (
|
||||
"log"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
@ -52,10 +50,6 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,6 +4,7 @@ package hscontrol
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -98,6 +99,10 @@ func (api headscaleV1APIServer) ListUsers(
|
|||
response[index] = user.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
})
|
||||
|
||||
log.Trace().Caller().Interface("users", response).Msg("")
|
||||
|
||||
return &v1.ListUsersResponse{Users: response}, nil
|
||||
|
@ -168,6 +173,10 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
|||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
})
|
||||
|
||||
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
|
||||
}
|
||||
|
||||
|
@ -186,6 +195,11 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
addrs, err := api.h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return db.RegisterNodeFromAuthCallback(
|
||||
tx,
|
||||
|
@ -194,7 +208,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
api.h.cfg.IPPrefixes,
|
||||
addrs,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -422,6 +436,10 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sort.Slice(nodes, func(i, j int) bool {
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
resp := node.Proto()
|
||||
|
@ -606,6 +624,10 @@ func (api headscaleV1APIServer) ListApiKeys(
|
|||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
})
|
||||
|
||||
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -620,6 +620,11 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
machineKey *key.MachinePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
addrs, err := h.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if _, err := db.RegisterNodeFromAuthCallback(
|
||||
// TODO(kradalby): find a better way to use the cache across modules
|
||||
|
@ -629,7 +634,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
user.Name,
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
h.cfg.IPPrefixes,
|
||||
addrs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
|
@ -47,9 +46,6 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
IPPrefixes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
OIDC: types.OIDCConfig{
|
||||
StripEmaildomain: false,
|
||||
},
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/prometheus/common/model"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -19,8 +20,6 @@ import (
|
|||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -41,7 +40,8 @@ type Config struct {
|
|||
GRPCAllowInsecure bool
|
||||
EphemeralNodeInactivityTimeout time.Duration
|
||||
NodeUpdateCheckInterval time.Duration
|
||||
IPPrefixes []netip.Prefix
|
||||
PrefixV4 *netip.Prefix
|
||||
PrefixV6 *netip.Prefix
|
||||
NoisePrivateKeyPath string
|
||||
BaseDomain string
|
||||
Log LogConfig
|
||||
|
@ -569,6 +569,39 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
|||
return nil, ""
|
||||
}
|
||||
|
||||
func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
|
||||
prefixV4Str := viper.GetString("prefixes.v4")
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
prefixV4, err := netip.ParsePrefix(prefixV4Str)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
prefixV6, err := netip.ParsePrefix(prefixV6Str)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
builder := netipx.IPSetBuilder{}
|
||||
builder.AddPrefix(tsaddr.CGNATRange())
|
||||
builder.AddPrefix(tsaddr.TailscaleULARange())
|
||||
ipSet, _ := builder.IPSet()
|
||||
if !ipSet.ContainsPrefix(prefixV4) {
|
||||
log.Warn().
|
||||
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
||||
prefixV4Str, tsaddr.CGNATRange())
|
||||
}
|
||||
|
||||
if !ipSet.ContainsPrefix(prefixV6) {
|
||||
log.Warn().
|
||||
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
||||
prefixV6Str, tsaddr.TailscaleULARange())
|
||||
}
|
||||
|
||||
return &prefixV4, &prefixV6, nil
|
||||
}
|
||||
|
||||
func GetHeadscaleConfig() (*Config, error) {
|
||||
if IsCLIConfigured() {
|
||||
return &Config{
|
||||
|
@ -581,66 +614,16 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
prefix4, prefix6, err := Prefixes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsConfig, baseDomain := GetDNSConfig()
|
||||
derpConfig := GetDERPConfig()
|
||||
logConfig := GetLogTailConfig()
|
||||
randomizeClientPort := viper.GetBool("randomize_client_port")
|
||||
|
||||
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
|
||||
parsedPrefixes := make([]netip.Prefix, 0, len(configuredPrefixes)+1)
|
||||
|
||||
for i, prefixInConfig := range configuredPrefixes {
|
||||
prefix, err := netip.ParsePrefix(prefixInConfig)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
builder := netipx.IPSetBuilder{}
|
||||
builder.AddPrefix(tsaddr.CGNATRange())
|
||||
ipSet, _ := builder.IPSet()
|
||||
if !ipSet.ContainsPrefix(prefix) {
|
||||
log.Warn().
|
||||
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
||||
prefixInConfig, tsaddr.CGNATRange())
|
||||
}
|
||||
}
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
builder := netipx.IPSetBuilder{}
|
||||
builder.AddPrefix(tsaddr.TailscaleULARange())
|
||||
ipSet, _ := builder.IPSet()
|
||||
if !ipSet.ContainsPrefix(prefix) {
|
||||
log.Warn().
|
||||
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
||||
prefixInConfig, tsaddr.TailscaleULARange())
|
||||
}
|
||||
}
|
||||
|
||||
parsedPrefixes = append(parsedPrefixes, prefix)
|
||||
}
|
||||
|
||||
prefixes := make([]netip.Prefix, 0, len(parsedPrefixes))
|
||||
{
|
||||
// dedup
|
||||
normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
|
||||
for i, p := range parsedPrefixes {
|
||||
normalized, _ := netipx.RangeOfPrefix(p).Prefix()
|
||||
normalizedPrefixes[normalized.String()] = i
|
||||
}
|
||||
|
||||
// convert back to list
|
||||
for _, i := range normalizedPrefixes {
|
||||
prefixes = append(prefixes, parsedPrefixes[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixes) < 1 {
|
||||
prefixes = append(prefixes, netip.MustParsePrefix("100.64.0.0/10"))
|
||||
log.Warn().
|
||||
Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
|
||||
}
|
||||
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
|
||||
if oidcClientSecretPath != "" && oidcClientSecret != "" {
|
||||
|
@ -662,7 +645,9 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
|
||||
IPPrefixes: prefixes,
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
|
||||
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
|
||||
viper.GetString("noise.private_key_path"),
|
||||
),
|
||||
|
|
|
@ -208,7 +208,6 @@ func (node *Node) IsEphemeral() bool {
|
|||
}
|
||||
|
||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
||||
|
||||
allowedIPs := append([]netip.Addr{}, node2.IPAddresses...)
|
||||
|
||||
for _, route := range node2.Routes {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue