Replace the timestamp-based state system
This commit replaces the timestamp-based state system with one that has update channels directly to the connected nodes: an update is sent to all listening clients via the polling mechanism. It introduces a new package, notifier, which contains a concurrency-safe manager for all the channels to the connected nodes.
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
parent 056d3a81c5
commit 66ff1fcd40
13 changed files with 216 additions and 731 deletions
hscontrol/app.go (108 lines changed; the only file shown in this excerpt)
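The notifier package itself is among the files not shown in this excerpt. As a rough sketch of the concurrency-safe channel manager the commit message describes, assuming illustrative names for everything except NewNotifier and NotifyAll (which appear in the diff below):

package notifier

import "sync"

// Notifier sketches a mutex-guarded map from a node's machine key to the
// channel its long-poll session listens on. Illustrative, not the real API.
type Notifier struct {
    mu    sync.RWMutex
    nodes map[string]chan<- struct{}
}

func NewNotifier() *Notifier {
    return &Notifier{nodes: make(map[string]chan<- struct{})}
}

// AddNode registers a node's update channel (assumed name).
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
    n.mu.Lock()
    defer n.mu.Unlock()
    n.nodes[machineKey] = c
}

// RemoveNode drops the channel when the node's poll session ends (assumed name).
func (n *Notifier) RemoveNode(machineKey string) {
    n.mu.Lock()
    defer n.mu.Unlock()
    delete(n.nodes, machineKey)
}

// NotifyAll signals every connected node. The non-blocking send is a design
// choice of this sketch: with a 1-buffered channel per node, a send that
// finds a signal already pending simply coalesces with it.
func (n *Notifier) NotifyAll() {
    n.mu.RLock()
    defer n.mu.RUnlock()
    for _, c := range n.nodes {
        select {
        case c <- struct{}{}:
        default:
        }
    }
}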
@@ -10,7 +10,6 @@ import (
     "net/http"
     "os"
     "os/signal"
-    "sort"
     "strconv"
     "strings"
     "sync"
@@ -26,13 +25,13 @@ import (
     "github.com/juanfont/headscale/hscontrol/db"
     "github.com/juanfont/headscale/hscontrol/derp"
     derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
+    "github.com/juanfont/headscale/hscontrol/notifier"
     "github.com/juanfont/headscale/hscontrol/policy"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/patrickmn/go-cache"
     zerolog "github.com/philip-bui/grpc-zerolog"
     "github.com/prometheus/client_golang/prometheus/promhttp"
-    "github.com/puzpuzpuz/xsync/v2"
     zl "github.com/rs/zerolog"
     "github.com/rs/zerolog/log"
     "golang.org/x/crypto/acme"
@@ -84,7 +83,7 @@ type Headscale struct {

     ACLPolicy *policy.ACLPolicy

-    lastStateChange *xsync.MapOf[string, time.Time]
+    nodeNotifier *notifier.Notifier

     oidcProvider *oidc.Provider
     oauth2Config *oauth2.Config
@@ -93,9 +92,6 @@ type Headscale struct {

     shutdownChan       chan struct{}
     pollNetMapStreamWG sync.WaitGroup
-
-    stateUpdateChan       chan struct{}
-    cancelStateUpdateChan chan struct{}
 }

 func NewHeadscale(cfg *types.Config) (*Headscale, error) {
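The struct swap above (lastStateChange out, nodeNotifier in) changes how a long-poll session learns about updates: it blocks on a channel instead of re-reading timestamps. A minimal sketch of the consuming side, building on the assumed AddNode/RemoveNode helpers from the sketch above; pollLoop, machineKey, and sendNetMap are hypothetical names, and the standard context import is assumed:

// pollLoop blocks until the notifier signals a change, then pushes a fresh
// netmap to this node; the buffered channel coalesces bursts of updates.
func (h *Headscale) pollLoop(
    ctx context.Context,
    machineKey string,
    sendNetMap func() error,
) error {
    updateChan := make(chan struct{}, 1)
    h.nodeNotifier.AddNode(machineKey, updateChan)
    defer h.nodeNotifier.RemoveNode(machineKey)

    for {
        select {
        case <-updateChan:
            if err := sendNetMap(); err != nil {
                return err
            }
        case <-ctx.Done():
            return ctx.Err()
        }
    }
}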
@@ -158,19 +154,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
         noisePrivateKey:    noisePrivateKey,
         registrationCache:  registrationCache,
         pollNetMapStreamWG: sync.WaitGroup{},
-        lastStateChange:    xsync.NewMapOf[time.Time](),
-
-        stateUpdateChan:       make(chan struct{}),
-        cancelStateUpdateChan: make(chan struct{}),
+        nodeNotifier:       notifier.NewNotifier(),
     }

-    go app.watchStateChannel()
-
     database, err := db.NewHeadscaleDatabase(
         cfg.DBtype,
         dbString,
         app.dbDebug,
-        app.stateUpdateChan,
+        app.nodeNotifier,
         cfg.IPPrefixes,
         cfg.BaseDomain)
     if err != nil {
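Note that db.NewHeadscaleDatabase now takes app.nodeNotifier where it previously took app.stateUpdateChan, so the database layer can fan out to nodes directly after a write instead of signalling the watcher goroutine. Roughly, under the assumption that a write path looks like this (SaveMachine and the gorm call are illustrative, not this commit's code):

func (hsdb *HSDatabase) SaveMachine(machine *types.Machine) error {
    if err := hsdb.db.Save(machine).Error; err != nil {
        return err
    }

    // Previously: hsdb.stateUpdateChan <- struct{}{}, which only bumped a
    // timestamp. Now the write path notifies the connected nodes itself.
    hsdb.notifier.NotifyAll()

    return nil
}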
@@ -203,7 +194,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {

     if cfg.DERP.ServerEnabled {
         // TODO(kradalby): replace this key with a dedicated DERP key.
-        embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP)
+        embeddedDERPServer, err := derpServer.NewDERPServer(
+            cfg.ServerURL,
+            key.NodePrivate(*privateKey),
+            &cfg.DERP,
+        )
         if err != nil {
             return nil, err
         }
@@ -230,10 +225,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {

 // expireExpiredMachines expires machines that have an explicit expiry set
 // after that expiry time has passed.
-func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
-    ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
+func (h *Headscale) expireExpiredMachines(intervalMs int64) {
+    interval := time.Duration(intervalMs) * time.Millisecond
+    ticker := time.NewTicker(interval)
+
+    lastCheck := time.Unix(0, 0)

     for range ticker.C {
-        h.db.ExpireExpiredMachines(h.getLastStateChange())
+        lastCheck = h.db.ExpireExpiredMachines(lastCheck)
     }
 }
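The rewritten worker also drops its dependency on getLastStateChange in favour of a local watermark: each sweep passes the time of the previous check, and the call site implies ExpireExpiredMachines returns the value to carry into the next iteration. The contract, with an illustrative body:

// ExpireExpiredMachines expires machines whose expiry passed since lastCheck
// and returns the new watermark. The signature is inferred from the call
// site above; the body is a guess at the query's shape.
func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
    started := time.Now()

    // ... expire every machine whose expiry is set, is in the past, and is
    // newer than lastCheck, then notify its peers ...

    return started
}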
@@ -258,7 +257,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
                h.DERPMap.Regions[region.RegionID] = &region
            }

-            h.setLastStateChangeToNow()
+            h.nodeNotifier.NotifyAll()
        }
    }
 }
@@ -722,7 +721,7 @@ func (h *Headscale) Serve() error {
                    Str("path", aclPath).
                    Msg("ACL policy successfully reloaded, notifying nodes of change")

-                h.setLastStateChangeToNow()
+                h.nodeNotifier.NotifyAll()
            }

        default:
@@ -760,10 +759,6 @@ func (h *Headscale) Serve() error {
        // Stop listening (and unlink the socket if unix type):
        socketListener.Close()

-        <-h.cancelStateUpdateChan
-        close(h.stateUpdateChan)
-        close(h.cancelStateUpdateChan)
-
        // Close db connections
        err = h.db.Close()
        if err != nil {
@@ -859,73 +854,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
    }
 }

-// TODO(kradalby): baby steps, make this more robust.
-func (h *Headscale) watchStateChannel() {
-    for {
-        select {
-        case <-h.stateUpdateChan:
-            h.setLastStateChangeToNow()
-
-        case <-h.cancelStateUpdateChan:
-            return
-        }
-    }
-}
-
-func (h *Headscale) setLastStateChangeToNow() {
-    var err error
-
-    now := time.Now().UTC()
-
-    users, err := h.db.ListUsers()
-    if err != nil {
-        log.Error().
-            Caller().
-            Err(err).
-            Msg("failed to fetch all users, failing to update last changed state.")
-    }
-
-    for _, user := range users {
-        lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix()))
-        if h.lastStateChange == nil {
-            h.lastStateChange = xsync.NewMapOf[time.Time]()
-        }
-        h.lastStateChange.Store(user.Name, now)
-    }
-}
-
-func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
-    times := []time.Time{}
-
-    // getLastStateChange takes a list of users as a "filter", if no users
-    // are past, then use the entier list of users and look for the last update
-    if len(users) > 0 {
-        for _, user := range users {
-            if lastChange, ok := h.lastStateChange.Load(user.Name); ok {
-                times = append(times, lastChange)
-            }
-        }
-    } else {
-        h.lastStateChange.Range(func(key string, value time.Time) bool {
-            times = append(times, value)
-
-            return true
-        })
-    }
-
-    sort.Slice(times, func(i, j int) bool {
-        return times[i].After(times[j])
-    })
-
-    log.Trace().Msgf("Latest times %#v", times)
-
-    if len(times) == 0 {
-        return time.Now().UTC()
-    } else {
-        return times[0]
-    }
-}
-
 func notFoundHandler(
    writer http.ResponseWriter,
    req *http.Request,
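Everything deleted above is the old model in miniature: writes signalled stateUpdateChan, watchStateChannel flattened those signals into per-user timestamps, and pollers decided staleness by comparison, roughly like this (an illustrative reconstruction built on the deleted getLastStateChange, not code from this commit; lastSent is a hypothetical per-session variable):

// staleSince reports whether global state changed after this session last
// sent a netmap, which is all the timestamp machinery ultimately answered.
func (h *Headscale) staleSince(lastSent time.Time, user types.User) bool {
    return h.getLastStateChange(user).After(lastSent)
}

The notifier collapses that wake-and-compare loop into a single channel receive, which is why all three helpers can be removed.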