Replace the timestamp-based state system

This commit replaces the timestamp-based state system with a new one that
maintains an update channel directly to each connected node. When state
changes, an update is sent to all listening clients via the polling
mechanism.
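Conceptually, the long-poll stream handler now blocks on the node's update
channel instead of repeatedly comparing timestamps. A rough sketch of the
idea follows; the names and signatures are illustrative only and are not the
actual headscale handler:

package sketch

import "context"

// pollStream is an illustrative sketch, not the real handler: it blocks on
// the node's update channel and pushes a fresh map response whenever the
// notifier signals a state change.
func pollStream(ctx context.Context, updates <-chan struct{}, sendMapResponse func() error) error {
    for {
        select {
        case <-ctx.Done():
            // Client went away or the server is shutting down.
            return ctx.Err()
        case <-updates:
            // Something changed in the control plane; push an update.
            if err := sendMapResponse(); err != nil {
                return err
            }
        }
    }
}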

It introduces a new package, notifier, which provides a concurrency-safe
manager for all the channels to the connected nodes.
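The manager keeps one update channel per connected node behind a mutex. A
simplified sketch of what such a manager could look like (the real package
may differ in channel types, keys, and method set):

package notifier

import "sync"

// Notifier is a simplified sketch of a concurrency-safe channel manager:
// one update channel per connected node, guarded by a mutex so nodes can
// register, deregister, and be notified from concurrent goroutines.
type Notifier struct {
    mu    sync.RWMutex
    nodes map[string]chan struct{}
}

func NewNotifier() *Notifier {
    return &Notifier{nodes: make(map[string]chan struct{})}
}

// AddNode registers the update channel for a node (keyed here by a
// machine-key string; the key type is an assumption of this sketch).
func (n *Notifier) AddNode(machineKey string, c chan struct{}) {
    n.mu.Lock()
    defer n.mu.Unlock()
    n.nodes[machineKey] = c
}

// RemoveNode drops the channel when the node disconnects.
func (n *Notifier) RemoveNode(machineKey string) {
    n.mu.Lock()
    defer n.mu.Unlock()
    delete(n.nodes, machineKey)
}

// NotifyAll signals every registered node that state has changed.
// Note: the send blocks until each node's poll loop receives, a
// deliberate simplification in this sketch.
func (n *Notifier) NotifyAll() {
    n.mu.RLock()
    defer n.mu.RUnlock()
    for _, c := range n.nodes {
        c <- struct{}{}
    }
}

On the server side, call sites that previously invoked setLastStateChangeToNow()
now call nodeNotifier.NotifyAll(), as the diff below shows.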

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Authored by Kristoffer Dalby on 2023-06-21 11:29:52 +02:00, committed by Kristoffer Dalby
parent 056d3a81c5
commit 66ff1fcd40
13 changed files with 216 additions and 731 deletions

hscontrol/app.go

@@ -10,7 +10,6 @@ import (
     "net/http"
     "os"
     "os/signal"
-    "sort"
     "strconv"
     "strings"
     "sync"
@@ -26,13 +25,13 @@ import (
     "github.com/juanfont/headscale/hscontrol/db"
     "github.com/juanfont/headscale/hscontrol/derp"
     derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
+    "github.com/juanfont/headscale/hscontrol/notifier"
     "github.com/juanfont/headscale/hscontrol/policy"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/patrickmn/go-cache"
     zerolog "github.com/philip-bui/grpc-zerolog"
     "github.com/prometheus/client_golang/prometheus/promhttp"
-    "github.com/puzpuzpuz/xsync/v2"
     zl "github.com/rs/zerolog"
     "github.com/rs/zerolog/log"
     "golang.org/x/crypto/acme"
@@ -84,7 +83,7 @@ type Headscale struct {
     ACLPolicy *policy.ACLPolicy
-    lastStateChange *xsync.MapOf[string, time.Time]
+    nodeNotifier *notifier.Notifier
     oidcProvider *oidc.Provider
     oauth2Config *oauth2.Config
@@ -93,9 +92,6 @@ type Headscale struct {
     shutdownChan chan struct{}
     pollNetMapStreamWG sync.WaitGroup
-    stateUpdateChan chan struct{}
-    cancelStateUpdateChan chan struct{}
 }

 func NewHeadscale(cfg *types.Config) (*Headscale, error) {
@@ -158,19 +154,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
         noisePrivateKey: noisePrivateKey,
         registrationCache: registrationCache,
         pollNetMapStreamWG: sync.WaitGroup{},
-        lastStateChange: xsync.NewMapOf[time.Time](),
-        stateUpdateChan: make(chan struct{}),
-        cancelStateUpdateChan: make(chan struct{}),
+        nodeNotifier: notifier.NewNotifier(),
     }
-    go app.watchStateChannel()
     database, err := db.NewHeadscaleDatabase(
         cfg.DBtype,
         dbString,
         app.dbDebug,
-        app.stateUpdateChan,
+        app.nodeNotifier,
         cfg.IPPrefixes,
         cfg.BaseDomain)
     if err != nil {
@@ -203,7 +194,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
     if cfg.DERP.ServerEnabled {
         // TODO(kradalby): replace this key with a dedicated DERP key.
-        embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP)
+        embeddedDERPServer, err := derpServer.NewDERPServer(
+            cfg.ServerURL,
+            key.NodePrivate(*privateKey),
+            &cfg.DERP,
+        )
         if err != nil {
             return nil, err
         }
@@ -230,10 +225,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
 // expireExpiredMachines expires machines that have an explicit expiry set
 // after that expiry time has passed.
-func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
-    ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
+func (h *Headscale) expireExpiredMachines(intervalMs int64) {
+    interval := time.Duration(intervalMs) * time.Millisecond
+    ticker := time.NewTicker(interval)
+    lastCheck := time.Unix(0, 0)
     for range ticker.C {
-        h.db.ExpireExpiredMachines(h.getLastStateChange())
+        lastCheck = h.db.ExpireExpiredMachines(lastCheck)
     }
 }
@@ -258,7 +257,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
                 h.DERPMap.Regions[region.RegionID] = &region
             }
-            h.setLastStateChangeToNow()
+            h.nodeNotifier.NotifyAll()
         }
     }
 }
@@ -722,7 +721,7 @@ func (h *Headscale) Serve() error {
                         Str("path", aclPath).
                         Msg("ACL policy successfully reloaded, notifying nodes of change")
-                    h.setLastStateChangeToNow()
+                    h.nodeNotifier.NotifyAll()
                 }
             default:
@@ -760,10 +759,6 @@ func (h *Headscale) Serve() error {
                 // Stop listening (and unlink the socket if unix type):
                 socketListener.Close()
-                <-h.cancelStateUpdateChan
-                close(h.stateUpdateChan)
-                close(h.cancelStateUpdateChan)
                 // Close db connections
                 err = h.db.Close()
                 if err != nil {
@@ -859,73 +854,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
     }
 }
-// TODO(kradalby): baby steps, make this more robust.
-func (h *Headscale) watchStateChannel() {
-    for {
-        select {
-        case <-h.stateUpdateChan:
-            h.setLastStateChangeToNow()
-        case <-h.cancelStateUpdateChan:
-            return
-        }
-    }
-}
-func (h *Headscale) setLastStateChangeToNow() {
-    var err error
-    now := time.Now().UTC()
-    users, err := h.db.ListUsers()
-    if err != nil {
-        log.Error().
-            Caller().
-            Err(err).
-            Msg("failed to fetch all users, failing to update last changed state.")
-    }
-    for _, user := range users {
-        lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix()))
-        if h.lastStateChange == nil {
-            h.lastStateChange = xsync.NewMapOf[time.Time]()
-        }
-        h.lastStateChange.Store(user.Name, now)
-    }
-}
-func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
-    times := []time.Time{}
-    // getLastStateChange takes a list of users as a "filter", if no users
-    // are passed, then use the entire list of users and look for the last update
-    if len(users) > 0 {
-        for _, user := range users {
-            if lastChange, ok := h.lastStateChange.Load(user.Name); ok {
-                times = append(times, lastChange)
-            }
-        }
-    } else {
-        h.lastStateChange.Range(func(key string, value time.Time) bool {
-            times = append(times, value)
-            return true
-        })
-    }
-    sort.Slice(times, func(i, j int) bool {
-        return times[i].After(times[j])
-    })
-    log.Trace().Msgf("Latest times %#v", times)
-    if len(times) == 0 {
-        return time.Now().UTC()
-    } else {
-        return times[0]
-    }
-}
 func notFoundHandler(
     writer http.ResponseWriter,
     req *http.Request,