Replace the timestamp based state system
This commit replaces the timestamp based state system with a new one that has update channels directly to the connected nodes. It will send an update to all listening clients via the polling mechanism. It introduces a new package notifier, which has a concurrency safe manager for all our channels to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
056d3a81c5
commit
66ff1fcd40
13 changed files with 216 additions and 731 deletions
|
@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
|
@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
|||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
|
@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
|||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
|
|
@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
|||
keys, err := db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||
|
@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
|||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||
|
@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
|||
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validWithErr, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||
|
@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
|
|||
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(notValid, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -36,8 +37,8 @@ type KV struct {
|
|||
}
|
||||
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
db *gorm.DB
|
||||
notifier *notifier.Notifier
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
|
@ -50,7 +51,7 @@ type HSDatabase struct {
|
|||
func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
debug bool,
|
||||
notifyStateChan chan<- struct{},
|
||||
notifier *notifier.Notifier,
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
|
@ -60,8 +61,8 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
db: dbConn,
|
||||
notifier: notifier,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
|
@ -297,10 +298,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
|||
)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) notifyStateChange() {
|
||||
hsdb.notifyStateChan <- struct{}{}
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV.
|
||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
|
|
|
@ -218,7 +218,7 @@ func (hsdb *HSDatabase) SetTags(
|
|||
}
|
||||
machine.ForcedTags = newTags
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||
|
@ -232,7 +232,7 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
|||
now := time.Now()
|
||||
machine.Expiry = &now
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||
|
@ -259,7 +259,7 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
|
|||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
|
@ -275,7 +275,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
|
|||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &expiry
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
|
@ -323,32 +323,6 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
|
||||
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
// Get the last update from all headscale users to compare with our nodes
|
||||
// last update.
|
||||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
||||
// This would mostly be for a bit of performance, and can be calculated based on
|
||||
// ACLs.
|
||||
lastUpdate := machine.CreatedAt
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastChange).
|
||||
Time("last_state_change", lastUpdate).
|
||||
Msgf("Checking if %s is missing updates", machine.Hostname)
|
||||
|
||||
return lastUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||
cache *cache.Cache,
|
||||
nodeKeyStr string,
|
||||
|
@ -626,7 +600,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
|||
}
|
||||
}
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -723,17 +697,22 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
|||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
||||
// use the time of the start of the function to ensure we
|
||||
// dont miss some machines by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
users, err := hsdb.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
return time.Unix(0, 0)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
|
@ -744,13 +723,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
return time.Unix(0, 0)
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for index, machine := range machines {
|
||||
if machine.IsExpired() &&
|
||||
machine.Expiry.After(lastChange) {
|
||||
machine.Expiry.After(lastCheck) {
|
||||
expiredFound = true
|
||||
|
||||
err := hsdb.ExpireMachine(&machines[index])
|
||||
|
@ -770,7 +749,9 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
}
|
||||
|
||||
return started
|
||||
}
|
||||
|
|
|
@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||
|
@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||
|
@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||
|
@ -131,8 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
|
@ -155,8 +147,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine(user.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||
|
@ -179,8 +169,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
|
@ -217,8 +205,6 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
|
@ -312,8 +298,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||
|
@ -349,8 +333,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||
|
@ -372,8 +354,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
|||
for i := range deserialized {
|
||||
c.Assert(deserialized[i], check.Equals, input[i])
|
||||
}
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
|
@ -418,8 +398,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
|
@ -463,8 +441,6 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
check.DeepEquals,
|
||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
|
@ -655,6 +631,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
|
|
@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
|||
// The machine record should have been deleted
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
|
|
|
@ -374,7 +374,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||
}
|
||||
|
||||
if routesChanged {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
|
@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||
|
@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||
|
@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(6))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||
|
@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@ package db
|
|||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
|
@ -20,14 +20,9 @@ type Suite struct{}
|
|||
var (
|
||||
tmpDir string
|
||||
db *HSDatabase
|
||||
|
||||
// channelUpdates counts the number of times
|
||||
// either of the channels was notified.
|
||||
channelUpdates int32
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
atomic.StoreInt32(&channelUpdates, 0)
|
||||
s.ResetDB(c)
|
||||
}
|
||||
|
||||
|
@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) {
|
|||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
func notificationSink(c <-chan struct{}) {
|
||||
for {
|
||||
<-c
|
||||
atomic.AddInt32(&channelUpdates, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) ResetDB(c *check.C) {
|
||||
if len(tmpDir) != 0 {
|
||||
os.RemoveAll(tmpDir)
|
||||
|
@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
c.Fatal(err)
|
||||
}
|
||||
|
||||
sink := make(chan struct{})
|
||||
|
||||
go notificationSink(sink)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
sink,
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue