Replace the timestamp based state system

This commit replaces the timestamp based state system with a new
one that has update channels directly to the connected nodes. It
will send an update to all listening clients via the polling
mechanism.

It introduces a new package notifier, which has a concurrency safe
manager for all our channels to the connected nodes.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-21 11:29:52 +02:00 committed by Kristoffer Dalby
parent 056d3a81c5
commit 66ff1fcd40
13 changed files with 216 additions and 731 deletions

View file

@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMultiIp(c *check.C) {
@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String())
c.Assert(channelUpdates, check.Equals, int32(0))
}

View file

@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
keys, err := db.ListAPIKeys()
c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) {
valid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
c.Assert(err, check.NotNil)
c.Assert(validWithErr, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (*Suite) TestExpireAPIKey(c *check.C) {
@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
notValid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(notValid, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
}

View file

@ -9,6 +9,7 @@ import (
"time"
"github.com/glebarez/sqlite"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -36,8 +37,8 @@ type KV struct {
}
type HSDatabase struct {
db *gorm.DB
notifyStateChan chan<- struct{}
db *gorm.DB
notifier *notifier.Notifier
ipAllocationMutex sync.Mutex
@ -50,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
dbType, connectionAddr string,
debug bool,
notifyStateChan chan<- struct{},
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
@ -60,8 +61,8 @@ func NewHeadscaleDatabase(
}
db := HSDatabase{
db: dbConn,
notifyStateChan: notifyStateChan,
db: dbConn,
notifier: notifier,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
@ -297,10 +298,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
)
}
func (hsdb *HSDatabase) notifyStateChange() {
hsdb.notifyStateChan <- struct{}{}
}
// getValue returns the value for the given key in KV.
func (hsdb *HSDatabase) getValue(key string) (string, error) {
var row KV

View file

@ -218,7 +218,7 @@ func (hsdb *HSDatabase) SetTags(
}
machine.ForcedTags = newTags
hsdb.notifyStateChange()
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
@ -232,7 +232,7 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
now := time.Now()
machine.Expiry = &now
hsdb.notifyStateChange()
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err)
@ -259,7 +259,7 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
}
machine.GivenName = newName
hsdb.notifyStateChange()
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err)
@ -275,7 +275,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
hsdb.notifyStateChange()
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf(
@ -323,32 +323,6 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
return nil
}
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
// Get the last update from all headscale users to compare with our nodes
// last update.
// TODO(kradalby): Only request updates from users where we can talk to nodes
// This would mostly be for a bit of performance, and can be calculated based on
// ACLs.
lastUpdate := machine.CreatedAt
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Trace().
Caller().
Str("machine", machine.Hostname).
Time("last_successful_update", lastChange).
Time("last_state_change", lastUpdate).
Msgf("Checking if %s is missing updates", machine.Hostname)
return lastUpdate.Before(lastChange)
}
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
cache *cache.Cache,
nodeKeyStr string,
@ -626,7 +600,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
}
}
hsdb.notifyStateChange()
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
return nil
}
@ -723,17 +697,22 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
}
if expiredFound {
hsdb.notifyStateChange()
hsdb.notifier.NotifyAll()
}
}
}
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
// use the time of the start of the function to ensure we
// dont miss some machines by returning it _after_ we have
// checked everything.
started := time.Now()
users, err := hsdb.ListUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
return time.Unix(0, 0)
}
for _, user := range users {
@ -744,13 +723,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
Str("user", user.Name).
Msg("Error listing machines in user")
return
return time.Unix(0, 0)
}
expiredFound := false
for index, machine := range machines {
if machine.IsExpired() &&
machine.Expiry.After(lastChange) {
machine.Expiry.After(lastCheck) {
expiredFound = true
err := hsdb.ExpireMachine(&machines[index])
@ -770,7 +749,9 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
}
if expiredFound {
hsdb.notifyStateChange()
hsdb.notifier.NotifyAll()
}
}
return started
}

View file

@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
_, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByID(c *check.C) {
@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
_, err = db.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
_, err = db.GetMachineByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
@ -131,8 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestDeleteMachine(c *check.C) {
@ -155,8 +147,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
_, err = db.GetMachine(user.Name, "testmachine")
c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestHardDeleteMachine(c *check.C) {
@ -179,8 +169,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
_, err = db.GetMachine(user.Name, "testmachine3")
c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestListPeers(c *check.C) {
@ -217,8 +205,6 @@ func (s *Suite) TestListPeers(c *check.C) {
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
@ -312,8 +298,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestExpireMachine(c *check.C) {
@ -349,8 +333,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(1))
}
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
@ -372,8 +354,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i])
}
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGenerateGivenName(c *check.C) {
@ -418,8 +398,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestSetTags(c *check.C) {
@ -463,8 +441,6 @@ func (s *Suite) TestSetTags(c *check.C) {
check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
)
c.Assert(channelUpdates, check.Equals, int32(2))
}
func TestHeadscale_generateGivenName(t *testing.T) {
@ -655,6 +631,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
c.Assert(err, check.IsNil)
c.Assert(enabledRoutes, check.HasLen, 4)
c.Assert(channelUpdates, check.Equals, int32(4))
}

View file

@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
// The machine record should have been deleted
_, err = db.GetMachine("test7", "testest")
c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(1))
}
func (*Suite) TestExpirePreauthKey(c *check.C) {

View file

@ -374,7 +374,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
}
if routesChanged {
hsdb.notifyStateChange()
hsdb.notifier.NotifyAll()
}
return nil

View file

@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
err = db.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(3))
}
func (s *Suite) TestIsUniquePrefix(c *check.C) {
@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
c.Assert(channelUpdates, check.Equals, int32(3))
}
func (s *Suite) TestSubnetFailover(c *check.C) {
@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(6))
}
func (s *Suite) TestDeleteRoutes(c *check.C) {
@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(2))
}

View file

@ -3,9 +3,9 @@ package db
import (
"net/netip"
"os"
"sync/atomic"
"testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"gopkg.in/check.v1"
)
@ -20,14 +20,9 @@ type Suite struct{}
var (
tmpDir string
db *HSDatabase
// channelUpdates counts the number of times
// either of the channels was notified.
channelUpdates int32
)
func (s *Suite) SetUpTest(c *check.C) {
atomic.StoreInt32(&channelUpdates, 0)
s.ResetDB(c)
}
@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) {
os.RemoveAll(tmpDir)
}
func notificationSink(c <-chan struct{}) {
for {
<-c
atomic.AddInt32(&channelUpdates, 1)
}
}
func (s *Suite) ResetDB(c *check.C) {
if len(tmpDir) != 0 {
os.RemoveAll(tmpDir)
@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) {
c.Fatal(err)
}
sink := make(chan struct{})
go notificationSink(sink)
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
sink,
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},