new IP allocator and add postgres to integration tests. (#1756)

This commit is contained in:
Kristoffer Dalby 2024-02-18 19:31:29 +01:00 committed by GitHub
parent f581d4d9c0
commit 384ca03208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
119 changed files with 3686 additions and 443 deletions

View file

@ -9,6 +9,7 @@ import (
"net"
"net/http"
_ "net/http/pprof" //nolint
"net/netip"
"os"
"os/signal"
"path/filepath"
@ -80,6 +81,7 @@ const (
type Headscale struct {
cfg *types.Config
db *db.HSDatabase
ipAlloc *db.IPAllocator
noisePrivateKey *key.MachinePrivate
DERPMap *tailcfg.DERPMap
@ -106,6 +108,7 @@ var (
)
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
var err error
if profilingEnabled {
runtime.SetBlockProfileRate(1)
}
@ -128,16 +131,17 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
nodeNotifier: notifier.NewNotifier(),
}
database, err := db.NewHeadscaleDatabase(
app.db, err = db.NewHeadscaleDatabase(
cfg.Database,
app.nodeNotifier,
cfg.IPPrefixes,
cfg.BaseDomain)
if err != nil {
return nil, err
}
app.db = database
app.ipAlloc, err = db.NewIPAllocator(app.db, *cfg.PrefixV4, *cfg.PrefixV6)
if err != nil {
return nil, err
}
if cfg.OIDC.Issuer != "" {
err = app.initOIDC()
@ -151,7 +155,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := util.GenerateMagicDNSRootDomains(app.cfg.IPPrefixes)
// TODO(kradalby): revisit why this takes a list.
magicDNSDomains := util.GenerateMagicDNSRootDomains([]netip.Prefix{*cfg.PrefixV4, *cfg.PrefixV6})
// we might have routes already from Split DNS
if app.cfg.DNSConfig.Routes == nil {
app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver)

View file

@ -388,8 +388,21 @@ func (h *Headscale) handleAuthKey(
ForcedTags: pak.Proto().GetAclTags(),
}
addrs, err := h.ipAlloc.Next()
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("failed to allocate IP ")
return
}
node, err = h.db.RegisterNode(
nodeToRegister,
addrs,
)
if err != nil {
log.Error().

View file

@ -1,106 +0,0 @@
// Codehere is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package db
import (
"errors"
"fmt"
"net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
return getAvailableIPs(rx, hsdb.ipPrefixes)
})
}
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
var ips types.NodeAddresses
var err error
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
ip, err = getAvailableIP(rx, ipPrefix)
if err != nil {
return ips, err
}
ips = append(ips, *ip)
}
return ips, err
}
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := getUsedIPs(rx)
if err != nil {
return nil, err
}
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
// Get the first IP in our prefix
ip := ipPrefixNetworkAddress.Next()
for {
if !ipPrefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
switch {
case ip.Compare(ipPrefixBroadcastAddress) == 0:
fallthrough
case usedIps.Contains(ip):
fallthrough
case ip == netip.Addr{} || ip.IsLoopback():
ip = ip.Next()
continue
default:
return &ip, nil
}
}
}
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
ipSet, err := ips.IPSet()
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
}
return ipSet, nil
}

View file

@ -1,196 +0,0 @@
package db
import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
tx := db.DB.Begin()
defer tx.Rollback()
ips, err := getAvailableIPs(tx, []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.Write(func(tx *gorm.DB) error {
return tx.Save(&node).Error
})
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expectedIPSetBuilder := netipx.IPSetBuilder{}
expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
c.Assert(usedIps.Contains(expected), check.Equals, true)
node1, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)
c.Assert(len(node1.IPAddresses), check.Equals, 1)
c.Assert(node1.IPAddresses[0], check.Equals, expected)
}
func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
ipPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
}
for index := 1; index <= 350; index++ {
tx := db.DB.Begin()
ips, err := getAvailableIPs(tx, ipPrefixes)
c.Assert(err, check.IsNil)
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = getNode(tx, "test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: uint64(index),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
tx.Save(&node)
c.Assert(tx.Commit().Error, check.IsNil)
}
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
expected9 := netip.MustParseAddr("10.27.0.10")
expected300 := netip.MustParseAddr("10.27.0.45")
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
notExpectedIPSetBuilder.Add(expected0)
notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300)
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
c.Assert(err, check.IsNil)
// We actually expect it to be a lot larger
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
c.Assert(usedIps.Contains(expected0), check.Equals, true)
c.Assert(usedIps.Contains(expected9), check.Equals, true)
c.Assert(usedIps.Contains(expected300), check.Equals, true)
// Check that we can read back the IPs
node1, err := db.GetNodeByID(1)
c.Assert(err, check.IsNil)
c.Assert(len(node1.IPAddresses), check.Equals, 1)
c.Assert(
node1.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.1"),
)
node50, err := db.GetNodeByID(50)
c.Assert(err, check.IsNil)
c.Assert(len(node50.IPAddresses), check.Equals, 1)
c.Assert(
node50.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.50"),
)
expectedNextIP := netip.MustParseAddr("10.27.1.95")
nextIP, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP), check.Equals, 1)
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
ips, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.DB.Save(&node)
ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String())
}

View file

@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"net/netip"
"path/filepath"
"strconv"
"strings"
@ -18,7 +17,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -35,7 +33,6 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
ipPrefixes []netip.Prefix
baseDomain string
}
@ -43,8 +40,6 @@ type HSDatabase struct {
// rather than arguments.
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
@ -327,7 +322,6 @@ func NewHeadscaleDatabase(
db := HSDatabase{
DB: dbConn,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
}
@ -351,6 +345,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
}
log.Info().
Str("database", types.DatabaseSqlite).
Str("path", cfg.Sqlite.Path).
Msg("Opening database")
db, err := gorm.Open(
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
@ -379,6 +378,11 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
cfg.Postgres.User,
)
log.Info().
Str("database", types.DatabasePostgres).
Str("path", dbString).
Msg("Opening database")
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"

148
hscontrol/db/ip.go Normal file
View file

@ -0,0 +1,148 @@
package db
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
// IPAllocator is a singleton responsible for allocating
// IP addresses for nodes and making sure the same
// address is not handed out twice. There can only be one
// and it needs to be created before any other database
// writes occur.
type IPAllocator struct {
mu sync.Mutex
prefix4 netip.Prefix
prefix6 netip.Prefix
// Previous IPs handed out
prev4 netip.Addr
prev6 netip.Addr
// Set of all IPs handed out.
// This might not be in sync with the database,
// but it is more conservative. If saves to the
// database fails, the IP will be allocated here
// until the next restart of Headscale.
usedIPs netipx.IPSetBuilder
}
// NewIPAllocator returns a new IPAllocator singleton which
// can be used to hand out unique IP addresses within the
// provided IPv4 and IPv6 prefix. It needs to be created
// when headscale starts and needs to finish its read
// transaction before any writes to the database occur.
func NewIPAllocator(db *HSDatabase, prefix4, prefix6 netip.Prefix) (*IPAllocator, error) {
var addressesSlices []string
if db != nil {
db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
})
}
var ips netipx.IPSetBuilder
// Add network and broadcast addrs to used pool so they
// are not handed out to nodes.
network4, broadcast4 := util.GetIPPrefixEndpoints(prefix4)
network6, broadcast6 := util.GetIPPrefixEndpoints(prefix6)
ips.Add(network4)
ips.Add(broadcast4)
ips.Add(network6)
ips.Add(broadcast6)
// Fetch all the IP Addresses currently handed out from the Database
// and add them to the used IP set.
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return nil, fmt.Errorf(
"parsing IPs from database %v: %w", machineAddresses,
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
// Build the initial IPSet to validate that we can use it.
_, err := ips.IPSet()
if err != nil {
return nil, fmt.Errorf(
"building initial IP Set: %w",
err,
)
}
return &IPAllocator{
usedIPs: ips,
prefix4: prefix4,
prefix6: prefix6,
// Use network as starting point, it will be used to call .Next()
// TODO(kradalby): Could potentially take all the IPs loaded from
// the database into account to start at a more "educated" location.
prev4: network4,
prev6: network6,
}, nil
}
func (i *IPAllocator) Next() (types.NodeAddresses, error) {
i.mu.Lock()
defer i.mu.Unlock()
v4, err := i.next(i.prev4, i.prefix4)
if err != nil {
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
}
v6, err := i.next(i.prev6, i.prefix6)
if err != nil {
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
}
return types.NodeAddresses{*v4, *v6}, nil
}
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
func (i *IPAllocator) next(prev netip.Addr, prefix netip.Prefix) (*netip.Addr, error) {
// Get the first IP in our prefix
ip := prev.Next()
// TODO(kradalby): maybe this can be done less often.
set, err := i.usedIPs.IPSet()
if err != nil {
return nil, err
}
for {
if !prefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
// Check if the IP has already been allocated.
if set.Contains(ip) {
ip = ip.Next()
continue
}
i.usedIPs.Add(ip)
return &ip, nil
}
}

151
hscontrol/db/ip_test.go Normal file
View file

@ -0,0 +1,151 @@
package db
import (
"net/netip"
"os"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
func TestIPAllocator(t *testing.T) {
mpp := func(pref string) netip.Prefix {
return netip.MustParsePrefix(pref)
}
na := func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
newDb := func() *HSDatabase {
tmpDir, err := os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
t.Fatalf("creating temp dir: %s", err)
}
db, _ = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
"",
)
return db
}
tests := []struct {
name string
dbFunc func() *HSDatabase
prefix4 netip.Prefix
prefix6 netip.Prefix
getCount int
want []types.NodeAddresses
}{
{
name: "simple",
dbFunc: func() *HSDatabase {
return nil
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
},
},
{
name: "simple-with-db",
dbFunc: func() *HSDatabase {
db := newDb()
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 1,
want: []types.NodeAddresses{
{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
},
},
{
name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase {
db := newDb()
db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
})
return db
},
prefix4: mpp("100.64.0.0/10"),
prefix6: mpp("fd7a:115c:a1e0::/48"),
getCount: 2,
want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
{
na("100.64.0.3"),
na("fd7a:115c:a1e0::3"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.dbFunc()
alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6)
spew.Dump(alloc)
t.Logf("prefixes: %q, %q", tt.prefix4.String(), tt.prefix6.String())
var got []types.NodeAddresses
for range tt.getCount {
gotSet, err := alloc.Next()
if err != nil {
t.Fatalf("allocating next IP: %s", err)
}
got = append(got, gotSet)
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -307,7 +307,7 @@ func RegisterNodeFromAuthCallback(
userName string,
nodeExpiry *time.Time,
registrationMethod string,
ipPrefixes []netip.Prefix,
addrs types.NodeAddresses,
) (*types.Node, error) {
log.Debug().
Str("machine_key", mkey.ShortString()).
@ -343,7 +343,7 @@ func RegisterNodeFromAuthCallback(
node, err := RegisterNode(
tx,
registrationNode,
ipPrefixes,
addrs,
)
if err == nil {
@ -359,14 +359,14 @@ func RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache
}
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
func (hsdb *HSDatabase) RegisterNode(node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, hsdb.ipPrefixes)
return RegisterNode(tx, node, addrs)
})
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
func RegisterNode(tx *gorm.DB, node types.Node, addrs types.NodeAddresses) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
@ -393,18 +393,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
return &node, nil
}
ips, err := getAvailableIPs(tx, ipPrefixes)
if err != nil {
log.Error().
Caller().
Err(err).
Str("node", node.Hostname).
Msg("Could not find IP for the new node")
return nil, err
}
node.IPAddresses = ips
node.IPAddresses = addrs
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
@ -413,7 +402,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*typ
log.Trace().
Caller().
Str("node", node.Hostname).
Str("ip", strings.Join(ips.StringSlice(), ",")).
Str("ip", strings.Join(addrs.StringSlice(), ",")).
Msg("Node registered with the database")
return &node, nil

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
@ -661,10 +660,6 @@ func TestFailoverRoute(t *testing.T) {
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
assert.NoError(t, err)

View file

@ -2,11 +2,9 @@ package db
import (
"log"
"net/netip"
"os"
"testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
)
@ -52,10 +50,6 @@ func (s *Suite) ResetDB(c *check.C) {
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
if err != nil {

View file

@ -4,6 +4,7 @@ package hscontrol
import (
"context"
"fmt"
"sort"
"strings"
"time"
@ -98,6 +99,10 @@ func (api headscaleV1APIServer) ListUsers(
response[index] = user.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
log.Trace().Caller().Interface("users", response).Msg("")
return &v1.ListUsersResponse{Users: response}, nil
@ -168,6 +173,10 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
}
@ -186,6 +195,11 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
addrs, err := api.h.ipAlloc.Next()
if err != nil {
return nil, err
}
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
return db.RegisterNodeFromAuthCallback(
tx,
@ -194,7 +208,7 @@ func (api headscaleV1APIServer) RegisterNode(
request.GetUser(),
nil,
util.RegisterMethodCLI,
api.h.cfg.IPPrefixes,
addrs,
)
})
if err != nil {
@ -422,6 +436,10 @@ func (api headscaleV1APIServer) ListNodes(
return nil, err
}
sort.Slice(nodes, func(i, j int) bool {
return nodes[i].ID < nodes[j].ID
})
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
resp := node.Proto()
@ -606,6 +624,10 @@ func (api headscaleV1APIServer) ListApiKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
}

View file

@ -620,6 +620,11 @@ func (h *Headscale) registerNodeForOIDCCallback(
machineKey *key.MachinePublic,
expiry time.Time,
) error {
addrs, err := h.ipAlloc.Next()
if err != nil {
return err
}
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
@ -629,7 +634,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
user.Name,
&expiry,
util.RegisterMethodOIDC,
h.cfg.IPPrefixes,
addrs,
); err != nil {
return err
}

View file

@ -1,7 +1,6 @@
package hscontrol
import (
"net/netip"
"os"
"testing"
@ -47,9 +46,6 @@ func (s *Suite) ResetDB(c *check.C) {
Path: tmpDir + "/headscale_test.db",
},
},
IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
OIDC: types.OIDCConfig{
StripEmaildomain: false,
},

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -19,8 +20,6 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"github.com/juanfont/headscale/hscontrol/util"
)
const (
@ -41,7 +40,8 @@ type Config struct {
GRPCAllowInsecure bool
EphemeralNodeInactivityTimeout time.Duration
NodeUpdateCheckInterval time.Duration
IPPrefixes []netip.Prefix
PrefixV4 *netip.Prefix
PrefixV6 *netip.Prefix
NoisePrivateKeyPath string
BaseDomain string
Log LogConfig
@ -569,6 +569,39 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
return nil, ""
}
func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
prefixV6Str := viper.GetString("prefixes.v6")
prefixV4, err := netip.ParsePrefix(prefixV4Str)
if err != nil {
return nil, nil, err
}
prefixV6, err := netip.ParsePrefix(prefixV6Str)
if err != nil {
return nil, nil, err
}
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.CGNATRange())
builder.AddPrefix(tsaddr.TailscaleULARange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefixV4) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixV4Str, tsaddr.CGNATRange())
}
if !ipSet.ContainsPrefix(prefixV6) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixV6Str, tsaddr.TailscaleULARange())
}
return &prefixV4, &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
return &Config{
@ -581,66 +614,16 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
prefix4, prefix6, err := Prefixes()
if err != nil {
return nil, err
}
dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig()
logConfig := GetLogTailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netip.Prefix, 0, len(configuredPrefixes)+1)
for i, prefixInConfig := range configuredPrefixes {
prefix, err := netip.ParsePrefix(prefixInConfig)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
}
if prefix.Addr().Is4() {
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.CGNATRange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefix) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixInConfig, tsaddr.CGNATRange())
}
}
if prefix.Addr().Is6() {
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.TailscaleULARange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefix) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixInConfig, tsaddr.TailscaleULARange())
}
}
parsedPrefixes = append(parsedPrefixes, prefix)
}
prefixes := make([]netip.Prefix, 0, len(parsedPrefixes))
{
// dedup
normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
for i, p := range parsedPrefixes {
normalized, _ := netipx.RangeOfPrefix(p).Prefix()
normalizedPrefixes[normalized.String()] = i
}
// convert back to list
for _, i := range normalizedPrefixes {
prefixes = append(prefixes, parsedPrefixes[i])
}
}
if len(prefixes) < 1 {
prefixes = append(prefixes, netip.MustParsePrefix("100.64.0.0/10"))
log.Warn().
Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
}
oidcClientSecret := viper.GetString("oidc.client_secret")
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
if oidcClientSecretPath != "" && oidcClientSecret != "" {
@ -662,7 +645,9 @@ func GetHeadscaleConfig() (*Config, error) {
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
IPPrefixes: prefixes,
PrefixV4: prefix4,
PrefixV6: prefix6,
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
viper.GetString("noise.private_key_path"),
),

View file

@ -208,7 +208,6 @@ func (node *Node) IsEphemeral() bool {
}
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
allowedIPs := append([]netip.Addr{}, node2.IPAddresses...)
for _, route := range node2.Routes {