batch updates in notifier (#1905)
This commit is contained in:
parent
fef8261339
commit
cb0b495ea9
8 changed files with 541 additions and 158 deletions
|
@ -3,7 +3,7 @@ package notifier
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -11,19 +11,27 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
l sync.RWMutex
|
||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n := &Notifier{
|
||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
// TODO(kradalby): clean this up
|
||||
go b.doWork()
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
|
@ -108,13 +116,8 @@ func (n *Notifier) NotifyWithIgnore(
|
|||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
for nodeID := range n.nodes {
|
||||
if slices.Contains(ignoreNodeIDs, nodeID) {
|
||||
continue
|
||||
}
|
||||
|
||||
n.NotifyByNodeID(ctx, update, nodeID)
|
||||
}
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByNodeID(
|
||||
|
@ -139,10 +142,10 @@ func (n *Notifier) NotifyByNodeID(
|
|||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("origin-hostname", ctx.Value("hostname")).
|
||||
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
||||
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
||||
Msgf("update not sent, context cancelled")
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
|
@ -151,11 +154,23 @@ func (n *Notifier) NotifyByNodeID(
|
|||
Any("origin", ctx.Value("origin")).
|
||||
Any("origin-hostname", ctx.Value("hostname")).
|
||||
Msgf("update successfully sent on chan")
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) sendAll(update types.StateUpdate) {
|
||||
start := time.Now()
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
for _, c := range n.nodes {
|
||||
c <- update
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) String() string {
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
@ -177,3 +192,166 @@ func (n *Notifier) String() string {
|
|||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type batcher struct {
|
||||
tick *time.Ticker
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cancelCh chan struct{}
|
||||
|
||||
changedNodeIDs set.Slice[types.NodeID]
|
||||
nodesChanged bool
|
||||
patches map[types.NodeID]tailcfg.PeerChange
|
||||
patchesChanged bool
|
||||
|
||||
n *Notifier
|
||||
}
|
||||
|
||||
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
||||
return &batcher{
|
||||
tick: time.NewTicker(batchTime),
|
||||
cancelCh: make(chan struct{}),
|
||||
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
||||
n: n,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (b *batcher) close() {
|
||||
b.cancelCh <- struct{}{}
|
||||
}
|
||||
|
||||
// addOrPassthrough adds the update to the batcher, if it is not a
|
||||
// type that is currently batched, it will be sent immediately.
|
||||
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch update.Type {
|
||||
case types.StatePeerChanged:
|
||||
b.changedNodeIDs.Add(update.ChangeNodes...)
|
||||
b.nodesChanged = true
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
for _, newPatch := range update.ChangePatches {
|
||||
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
||||
overwritePatch(&curr, newPatch)
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
||||
} else {
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
||||
}
|
||||
}
|
||||
b.patchesChanged = true
|
||||
|
||||
default:
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
}
|
||||
|
||||
// flush sends all the accumulated patches to all
|
||||
// nodes in the notifier.
|
||||
func (b *batcher) flush() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.nodesChanged || b.patchesChanged {
|
||||
var patches []*tailcfg.PeerChange
|
||||
// If a node is getting a full update from a change
|
||||
// node update, then the patch can be dropped.
|
||||
for nodeID, patch := range b.patches {
|
||||
if b.changedNodeIDs.Contains(nodeID) {
|
||||
delete(b.patches, nodeID)
|
||||
} else {
|
||||
patches = append(patches, &patch)
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
||||
sort.Slice(changedNodes, func(i, j int) bool {
|
||||
return changedNodes[i] < changedNodes[j]
|
||||
})
|
||||
|
||||
if b.changedNodeIDs.Slice().Len() > 0 {
|
||||
update := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changedNodes,
|
||||
}
|
||||
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
|
||||
if len(patches) > 0 {
|
||||
patchUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: patches,
|
||||
}
|
||||
|
||||
b.n.sendAll(patchUpdate)
|
||||
}
|
||||
|
||||
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
||||
b.nodesChanged = false
|
||||
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
||||
b.patchesChanged = false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) doWork() {
|
||||
for {
|
||||
select {
|
||||
case <-b.cancelCh:
|
||||
return
|
||||
case <-b.tick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// overwritePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed
|
||||
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
|
||||
if newPatch.Capabilities != nil {
|
||||
currPatch.Capabilities = newPatch.Capabilities
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue