replace ephemeral deletion logic (#2008)

* replace ephemeral deletion logic

This commit replaces the way we remove ephemeral nodes.
Currently they are deleted in a loop that looks at the last seen
time. That time is now only set when a node disconnects, and
there was a bug (#2006) where nodes that had never disconnected
were deleted because they did not have a last seen time.

The new logic will start an expiry timer when the node disconnects
and delete the node from the database when the timer is up.

If the node reconnects within the expiry, the timer is cancelled.

Fixes #2006

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use uint64 as authkeyid and ptr helper in tests

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add test db helper

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add list ephemeral node func

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* schedule ephemeral nodes for removal on startup

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix gorm query for postgres

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add godoc

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-07-18 10:01:59 +02:00 committed by GitHub
parent 58bd38a609
commit 7e62031444
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 417 additions and 206 deletions

View file

@ -1,17 +1,23 @@
package db
import (
"crypto/rand"
"fmt"
"math/big"
"net/netip"
"regexp"
"strconv"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
@ -30,7 +36,6 @@ func (s *Suite) TestGetNode(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
@ -39,7 +44,7 @@ func (s *Suite) TestGetNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
@ -61,7 +66,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -69,7 +73,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
@ -93,7 +97,6 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -101,7 +104,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
@ -145,7 +148,6 @@ func (s *Suite) TestListPeers(c *check.C) {
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)
pakID := uint(pak.ID)
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
@ -157,7 +159,7 @@ func (s *Suite) TestListPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
@ -197,7 +199,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(stor[index%2].key.ID)
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{
@ -208,7 +209,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(stor[index%2].key.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
@ -283,7 +284,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
@ -292,7 +292,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
Expiry: &time.Time{},
}
db.DB.Save(node)
@ -328,7 +328,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
machineKey2 := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -337,7 +336,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
@ -372,7 +371,6 @@ func (s *Suite) TestSetTags(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -380,7 +378,7 @@ func (s *Suite) TestSetTags(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
@ -566,7 +564,6 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
route2 := netip.MustParsePrefix("10.11.0.0/24")
v4 := netip.MustParseAddr("100.64.0.1")
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
@ -574,7 +571,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
AuthKeyID: ptr.To(pak.ID),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
@ -600,3 +597,121 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(enabledRoutes, check.HasLen, 4)
}
// TestEphemeralGarbageCollectorOrder schedules four nodes with increasing
// expiries, cancels two of them, and checks that only the remaining two are
// passed to the delete callback, in the expected order.
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
	want := []types.NodeID{1, 3}
	got := []types.NodeID{}

	// The collector is started on its own goroutine below, so the callback is
	// presumably not invoked on the test goroutine. Guard got with a mutex,
	// the same way TestEphemeralGarbageCollectorLoads does.
	var mu sync.Mutex

	e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
		mu.Lock()
		defer mu.Unlock()

		got = append(got, ni)
	})
	go e.Start()

	e.Schedule(1, 1*time.Second)
	e.Schedule(2, 2*time.Second)
	e.Schedule(3, 3*time.Second)
	e.Schedule(4, 4*time.Second)

	// Nodes 2 and 4 are cancelled before their expiry and must not be deleted.
	e.Cancel(2)
	e.Cancel(4)

	// Wait longer than the largest scheduled expiry (4s) so every timer that
	// was not cancelled has had a chance to fire.
	time.Sleep(6 * time.Second)

	e.Close()

	// Read got under the same lock the callback writes it with.
	mu.Lock()
	defer mu.Unlock()

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff)
	}
}
// TestEphemeralGarbageCollectorLoads schedules 1000 nodes from concurrent
// goroutines, each with a one second expiry, and checks that the delete
// callback recorded exactly 1000 deletions.
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
	var got []types.NodeID
	var mu sync.Mutex

	want := 1000

	e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
		mu.Lock()
		defer mu.Unlock()

		// Hold the lock for 1-3ms to imitate a delete that takes some time.
		// NOTE(review): generateRandomNumber may call t.Fatalf; if this
		// callback runs off the test goroutine that is not permitted by the
		// testing package — confirm against the collector implementation.
		time.Sleep(time.Duration(generateRandomNumber(t, 3)) * time.Millisecond)
		got = append(got, ni)
	})
	go e.Start()

	// types.NodeID(i) is evaluated when the go statement runs, so every
	// goroutine schedules a distinct ID.
	for i := 0; i < want; i++ {
		go e.Schedule(types.NodeID(i), 1*time.Second)
	}

	// Leave ample time for all one second timers to fire and for the
	// serialized callbacks to finish.
	time.Sleep(10 * time.Second)

	e.Close()

	// Every write to got happens under mu, so the final read takes it too.
	mu.Lock()
	defer mu.Unlock()

	if len(got) != want {
		t.Errorf("expected %d, got %d", want, len(got))
	}
}
// generateRandomNumber returns a random integer in the inclusive range
// [1, max], drawn from crypto/rand. If the random source reports an error the
// test is failed immediately via t.Fatalf.
func generateRandomNumber(t *testing.T, max int64) int64 {
	t.Helper()

	// rand.Int yields a value in [0, max); adding one shifts it to [1, max].
	upper := new(big.Int).SetInt64(max)

	picked, err := rand.Int(rand.Reader, upper)
	if err != nil {
		t.Fatalf("getting random number: %s", err)
	}

	result := picked.Int64()
	result++

	return result
}
// TestListEphemeralNodes saves two nodes, each registered with a different
// pre-auth key, and checks that ListNodes returns both while
// ListEphemeralNodes returns only the node registered with pakEph.
func TestListEphemeralNodes(t *testing.T) {
	db, err := newTestDB()
	if err != nil {
		t.Fatalf("creating db: %s", err)
	}

	// Setup failures are fatal: the values created here are used right away,
	// so continuing after an error would fail in a confusing place instead of
	// reporting the real cause.
	user, err := db.CreateUser("test")
	if err != nil {
		t.Fatalf("creating user: %s", err)
	}

	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
	if err != nil {
		t.Fatalf("creating pre-auth key: %s", err)
	}

	// Differs from pak only in the third argument — presumably the ephemeral
	// flag; confirm against CreatePreAuthKey.
	pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
	if err != nil {
		t.Fatalf("creating ephemeral pre-auth key: %s", err)
	}

	node := types.Node{
		ID:             0,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "test",
		UserID:         user.ID,
		RegisterMethod: util.RegisterMethodAuthKey,
		AuthKeyID:      ptr.To(pak.ID),
	}

	nodeEph := types.Node{
		ID:             0,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "ephemeral",
		UserID:         user.ID,
		RegisterMethod: util.RegisterMethodAuthKey,
		AuthKeyID:      ptr.To(pakEph.ID),
	}

	err = db.DB.Save(&node).Error
	assert.NoError(t, err)

	err = db.DB.Save(&nodeEph).Error
	assert.NoError(t, err)

	nodes, err := db.ListNodes()
	assert.NoError(t, err)

	ephemeralNodes, err := db.ListEphemeralNodes()
	assert.NoError(t, err)

	assert.Len(t, nodes, 2)

	// Stop here on a length mismatch so the index below cannot go out of
	// range and turn an assertion failure into a panic.
	if !assert.Len(t, ephemeralNodes, 1) {
		return
	}

	assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID)
	assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID)
	assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
	assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
}