batch updates in notifier (#1905)
This commit is contained in:
parent
fef8261339
commit
cb0b495ea9
8 changed files with 541 additions and 158 deletions
|
@ -18,7 +18,12 @@ var (
|
|||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type"})
|
||||
}, []string{"status", "type", "trigger"})
|
||||
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_received_total",
|
||||
Help: "total count of updates received by notifier",
|
||||
}, []string{"type", "trigger"})
|
||||
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_open_channels_total",
|
||||
|
|
|
@ -3,7 +3,7 @@ package notifier
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -11,19 +11,27 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
l sync.RWMutex
|
||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n := &Notifier{
|
||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
// TODO(kradalby): clean this up
|
||||
go b.doWork()
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
|
@ -108,13 +116,8 @@ func (n *Notifier) NotifyWithIgnore(
|
|||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
for nodeID := range n.nodes {
|
||||
if slices.Contains(ignoreNodeIDs, nodeID) {
|
||||
continue
|
||||
}
|
||||
|
||||
n.NotifyByNodeID(ctx, update, nodeID)
|
||||
}
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByNodeID(
|
||||
|
@ -139,10 +142,10 @@ func (n *Notifier) NotifyByNodeID(
|
|||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("origin-hostname", ctx.Value("hostname")).
|
||||
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
||||
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
||||
Msgf("update not sent, context cancelled")
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
|
@ -151,11 +154,23 @@ func (n *Notifier) NotifyByNodeID(
|
|||
Any("origin", ctx.Value("origin")).
|
||||
Any("origin-hostname", ctx.Value("hostname")).
|
||||
Msgf("update successfully sent on chan")
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) sendAll(update types.StateUpdate) {
|
||||
start := time.Now()
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
for _, c := range n.nodes {
|
||||
c <- update
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) String() string {
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
@ -177,3 +192,166 @@ func (n *Notifier) String() string {
|
|||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type batcher struct {
|
||||
tick *time.Ticker
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cancelCh chan struct{}
|
||||
|
||||
changedNodeIDs set.Slice[types.NodeID]
|
||||
nodesChanged bool
|
||||
patches map[types.NodeID]tailcfg.PeerChange
|
||||
patchesChanged bool
|
||||
|
||||
n *Notifier
|
||||
}
|
||||
|
||||
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
||||
return &batcher{
|
||||
tick: time.NewTicker(batchTime),
|
||||
cancelCh: make(chan struct{}),
|
||||
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
||||
n: n,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (b *batcher) close() {
|
||||
b.cancelCh <- struct{}{}
|
||||
}
|
||||
|
||||
// addOrPassthrough adds the update to the batcher, if it is not a
|
||||
// type that is currently batched, it will be sent immediately.
|
||||
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch update.Type {
|
||||
case types.StatePeerChanged:
|
||||
b.changedNodeIDs.Add(update.ChangeNodes...)
|
||||
b.nodesChanged = true
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
for _, newPatch := range update.ChangePatches {
|
||||
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
||||
overwritePatch(&curr, newPatch)
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
||||
} else {
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
||||
}
|
||||
}
|
||||
b.patchesChanged = true
|
||||
|
||||
default:
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
}
|
||||
|
||||
// flush sends all the accumulated patches to all
|
||||
// nodes in the notifier.
|
||||
func (b *batcher) flush() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.nodesChanged || b.patchesChanged {
|
||||
var patches []*tailcfg.PeerChange
|
||||
// If a node is getting a full update from a change
|
||||
// node update, then the patch can be dropped.
|
||||
for nodeID, patch := range b.patches {
|
||||
if b.changedNodeIDs.Contains(nodeID) {
|
||||
delete(b.patches, nodeID)
|
||||
} else {
|
||||
patches = append(patches, &patch)
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
||||
sort.Slice(changedNodes, func(i, j int) bool {
|
||||
return changedNodes[i] < changedNodes[j]
|
||||
})
|
||||
|
||||
if b.changedNodeIDs.Slice().Len() > 0 {
|
||||
update := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changedNodes,
|
||||
}
|
||||
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
|
||||
if len(patches) > 0 {
|
||||
patchUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: patches,
|
||||
}
|
||||
|
||||
b.n.sendAll(patchUpdate)
|
||||
}
|
||||
|
||||
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
||||
b.nodesChanged = false
|
||||
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
||||
b.patchesChanged = false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) doWork() {
|
||||
for {
|
||||
select {
|
||||
case <-b.cancelCh:
|
||||
return
|
||||
case <-b.tick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// overwritePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed
|
||||
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
|
||||
if newPatch.Capabilities != nil {
|
||||
currPatch.Capabilities = newPatch.Capabilities
|
||||
}
|
||||
}
|
||||
|
|
249
hscontrol/notifier/notifier_test.go
Normal file
249
hscontrol/notifier/notifier_test.go
Normal file
|
@ -0,0 +1,249 @@
|
|||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestBatcher(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
updates []types.StateUpdate
|
||||
want []types.StateUpdate
|
||||
}{
|
||||
{
|
||||
name: "full-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "derp-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3, 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-patch-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-same-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-multiple-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := NewNotifier(&types.Config{
|
||||
Tuning: types.Tuning{
|
||||
// We will call flush manually for the tests,
|
||||
// so do not run the worker.
|
||||
BatchChangeDelay: time.Hour,
|
||||
},
|
||||
})
|
||||
|
||||
ch := make(chan types.StateUpdate, 30)
|
||||
defer close(ch)
|
||||
n.AddNode(1, ch)
|
||||
defer n.RemoveNode(1)
|
||||
|
||||
for _, u := range tt.updates {
|
||||
n.NotifyAll(context.Background(), u)
|
||||
}
|
||||
|
||||
n.b.flush()
|
||||
|
||||
var got []types.StateUpdate
|
||||
for len(ch) > 0 {
|
||||
out := <-ch
|
||||
got = append(got, out)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue