Rework map session

This commit restructures the map session into a struct that holds
the state it needs during its lifetime.
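
A rough sketch of the idea, with hypothetical names (the actual
session type lives in the poll handling code and its fields may
differ):

type mapSession struct {
	// Hypothetical sketch only; field names are illustrative.
	mapper        *Mapper
	req           tailcfg.MapRequest
	node          *types.Node
	w             http.ResponseWriter

	ch            chan types.StateUpdate // updates pushed by the rest of the system
	cancel        context.CancelFunc
	batchInterval time.Duration // how long to collect updates before flushing
}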

For streaming sessions, the event loop is structured a bit
differently: instead of hammering the clients with individual
updates, it batches them over a short, configurable interval,
which should significantly improve CPU usage and potentially
reduce flakiness.
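
A minimal sketch of the batching loop, assuming a configurable batch
interval and a hypothetical flush helper that writes one coalesced
MapResponse; it is not the literal implementation:

func (s *mapSession) serveLongPoll(ctx context.Context) {
	ticker := time.NewTicker(s.batchInterval) // e.g. a few hundred milliseconds
	defer ticker.Stop()

	var pending []types.StateUpdate
	for {
		select {
		case <-ctx.Done():
			return
		case update := <-s.ch:
			// Collect updates instead of writing each one to the client.
			pending = append(pending, update)
		case <-ticker.C:
			if len(pending) == 0 {
				continue
			}
			// One coalesced response per tick instead of one per update.
			s.flush(pending) // hypothetical helper
			pending = pending[:0]
		}
	}
}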

The use of Patch updates has been dialed back a little, as it does
not look like it is 100% ready for prime time. Nodes are now updated
with full changes, except for a few things like online status.
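
Online status keeps using patches; a hedged sketch of building such a
minimal patch (the helper is illustrative, while tailcfg.PeerChange
and its Online field come from the upstream tailcfg package):

func onlinePatch(id types.NodeID, online bool) *tailcfg.PeerChange {
	return &tailcfg.PeerChange{
		NodeID: id.NodeID(),
		Online: &online,
	}
}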

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby authored on 2024-02-23 10:59:24 +01:00; committed by Juan Font
parent dd693c444c
commit 58c94d2bd3
35 changed files with 1803 additions and 1716 deletions


@ -16,12 +16,12 @@ import (
"time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"golang.org/x/exp/maps"
"tailscale.com/envknob"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
@ -51,21 +51,14 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
type Mapper struct {
// Configuration
// TODO(kradalby): figure out if this is the format we want this in
derpMap *tailcfg.DERPMap
baseDomain string
dnsCfg *tailcfg.DNSConfig
logtail bool
randomClientPort bool
db *db.HSDatabase
cfg *types.Config
derpMap *tailcfg.DERPMap
isLikelyConnected types.NodeConnectedMap
uid string
created time.Time
seq uint64
// Map isn't concurrency safe, so we need to ensure
// only one func is accessing it over time.
mu sync.Mutex
peers map[uint64]*types.Node
patches map[uint64][]patch
}
type patch struct {
@ -74,35 +67,22 @@ type patch struct {
}
func NewMapper(
node *types.Node,
peers types.Nodes,
db *db.HSDatabase,
cfg *types.Config,
derpMap *tailcfg.DERPMap,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
logtail bool,
randomClientPort bool,
isLikelyConnected types.NodeConnectedMap,
) *Mapper {
log.Debug().
Caller().
Str("node", node.Hostname).
Msg("creating new mapper")
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{
derpMap: derpMap,
baseDomain: baseDomain,
dnsCfg: dnsCfg,
logtail: logtail,
randomClientPort: randomClientPort,
db: db,
cfg: cfg,
derpMap: derpMap,
isLikelyConnected: isLikelyConnected,
uid: uid,
created: time.Now(),
seq: 0,
// TODO: populate
peers: peers.IDMap(),
patches: make(map[uint64][]patch),
}
}
@ -207,11 +187,10 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
node *types.Node,
peers types.Nodes,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
peers := nodeMapToList(m.peers)
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
if err != nil {
return nil, err
@ -219,14 +198,13 @@ func (m *Mapper) fullMapResponse(
err = appendPeerChanges(
resp,
true, // full change
pol,
node,
capVer,
peers,
peers,
m.baseDomain,
m.dnsCfg,
m.randomClientPort,
m.cfg,
)
if err != nil {
return nil, err
@ -240,35 +218,25 @@ func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
peers := maps.Keys(m.peers)
peersWithPatches := maps.Keys(m.patches)
slices.Sort(peers)
slices.Sort(peersWithPatches)
if len(peersWithPatches) > 0 {
log.Debug().
Str("node", node.Hostname).
Uints64("peers", peers).
Uints64("pending_patches", peersWithPatches).
Msgf("node requested full map response, but has pending patches")
}
resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
peers, err := m.ListPeers(node.ID)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// LiteMapResponse returns a MapResponse for the given node.
// ReadOnlyResponse returns a MapResponse for the given node.
// Lite means that the peers have been omitted; this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) LiteMapResponse(
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
@ -279,18 +247,6 @@ func (m *Mapper) LiteMapResponse(
return nil, err
}
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
node,
nodeMapToList(m.peers),
)
if err != nil {
return nil, err
}
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.SSHPolicy = sshPolicy
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
@ -320,50 +276,74 @@ func (m *Mapper) DERPMapResponse(
func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed types.Nodes,
changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Update our internal map.
for _, node := range changed {
if patches, ok := m.patches[node.ID]; ok {
// preserve online status in case the patch has an outdated one
online := node.IsOnline
for _, p := range patches {
// TODO(kradalby): Figure if this needs to be sorted by timestamp
node.ApplyPeerChange(p.change)
}
// Ensure the patches are not applied again later
delete(m.patches, node.ID)
node.IsOnline = online
}
m.peers[node.ID] = node
}
resp := m.baseMapResponse()
err := appendPeerChanges(
peers, err := m.ListPeers(node.ID)
if err != nil {
return nil, err
}
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
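// The changed map doubles as a removal set: a value of true means the
// peer has a full update to send, false means the peer was removed and
// only its ID needs to go into PeersRemoved.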
for nodeID, nodeChanged := range changed {
if nodeChanged {
changedIDs = append(changedIDs, nodeID)
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := make(types.Nodes, 0, len(changedIDs))
for _, peer := range peers {
if slices.Contains(changedIDs, peer.ID) {
changedNodes = append(changedNodes, peer)
}
}
err = appendPeerChanges(
&resp,
false, // partial change
pol,
node,
mapRequest.Version,
nodeMapToList(m.peers),
changed,
m.baseDomain,
m.dnsCfg,
m.randomClientPort,
peers,
changedNodes,
m.cfg,
)
if err != nil {
return nil, err
}
resp.PeersRemoved = removedIDs
// Sending patches as a part of a PeersChanged response
// is technically not supposed to be done, but they are
// applied after the PeersChanged. The patch list
// should _only_ contain Nodes that are not in the
// PeersChanged or PeersRemoved list and the caller
// should filter them out.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
// the Peers* fields also set.
if patches != nil {
resp.PeersChangedPatch = patches
}
// Add the node itself; it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
}
@ -375,71 +355,12 @@ func (m *Mapper) PeerChangedPatchResponse(
changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
sendUpdate := false
// patch the internal map
for _, change := range changed {
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
peer.ApplyPeerChange(change)
sendUpdate = true
} else {
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
p := patch{
timestamp: time.Now(),
change: change,
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
m.patches[uint64(change.NodeID)] = append(patches, p)
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
}
}
if !sendUpdate {
return nil, nil
}
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
// TODO(kradalby): We need some integration tests for this.
func (m *Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
removed []tailcfg.NodeID,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Some nodes might have been removed already
// so we don't want to ask downstream to remove it
// twice, as that can cause a panic in tailscaled.
notYetRemoved := []tailcfg.NodeID{}
// remove from our internal map
for _, id := range removed {
if _, ok := m.peers[uint64(id)]; ok {
notYetRemoved = append(notYetRemoved, id)
}
delete(m.peers, uint64(id))
delete(m.patches, uint64(id))
}
resp := m.baseMapResponse()
resp.PeersRemoved = notYetRemoved
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
@ -469,10 +390,8 @@ func (m *Mapper) marshalMapResponse(
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case isSelfUpdate(messages...):
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
responseType = "lite"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
@ -496,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
panic(err)
}
now := time.Now().UnixNano()
now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -574,7 +493,7 @@ func (m *Mapper) baseWithConfigMapResponse(
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort)
tailnode, err := tailNode(node, capVer, pol, m.cfg)
if err != nil {
return nil, err
}
@ -582,7 +501,7 @@ func (m *Mapper) baseWithConfigMapResponse(
resp.DERPMap = m.derpMap
resp.Domain = m.baseDomain
resp.Domain = m.cfg.BaseDomain
// Do not instruct clients to collect services we do not
// support or do anything with them
@ -591,12 +510,26 @@ func (m *Mapper) baseWithConfigMapResponse(
resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.logtail,
DisableLogTail: !m.cfg.LogTail.Enabled,
}
return &resp, nil
}
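
// ListPeers returns the peers of the given node from the database,
// with the online flag layered on from the in-memory connection map.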
func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID)
if err != nil {
return nil, err
}
for _, peer := range peers {
online := m.isLikelyConnected[peer.ID]
peer.IsOnline = &online
}
return peers, nil
}
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
ret := make(types.Nodes, 0)
@ -612,42 +545,41 @@ func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
pol *policy.ACLPolicy,
node *types.Node,
capVer tailcfg.CapabilityVersion,
peers types.Nodes,
changed types.Nodes,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
randomClientPort bool,
cfg *types.Config,
) error {
fullChange := len(peers) == len(changed)
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
node,
peers,
)
packetFilter, err := pol.CompileFilterRules(append(peers, node))
if err != nil {
return err
}
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
if err != nil {
return err
}
// If there are filter rules present, see if there are any nodes that cannot
// access each other at all and remove them from the peers.
if len(rules) > 0 {
changed = policy.FilterNodesByACL(node, changed, rules)
if len(packetFilter) > 0 {
changed = policy.FilterNodesByACL(node, changed, packetFilter)
}
profiles := generateUserProfiles(node, changed, baseDomain)
profiles := generateUserProfiles(node, changed, cfg.BaseDomain)
dnsConfig := generateDNSConfig(
dnsCfg,
baseDomain,
cfg.DNSConfig,
cfg.BaseDomain,
node,
peers,
)
tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort)
tailPeers, err := tailNodes(changed, capVer, pol, cfg)
if err != nil {
return err
}
@ -663,19 +595,9 @@ func appendPeerChanges(
resp.PeersChanged = tailPeers
}
resp.DNSConfig = dnsConfig
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.PacketFilter = policy.ReduceFilterRules(node, packetFilter)
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
return nil
}
func isSelfUpdate(messages ...string) bool {
for _, message := range messages {
if strings.Contains(message, types.SelfUpdateIdentifier) {
return true
}
}
return false
}
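
Taken together, the mapper now pulls peers from the database on demand
instead of mirroring them in an internal map. A hedged usage sketch of
the slimmed-down API (db, cfg, derpMap, isLikelyConnected, mapReq,
node, pol, peerA and peerB stand in for values wired up elsewhere, and
error handling is abbreviated):

mapper := NewMapper(db, cfg, derpMap, isLikelyConnected)

// Full map for a node, e.g. on initial connect.
full, err := mapper.FullMapResponse(mapReq, node, pol)
if err != nil {
	return err
}

// Incremental update: true marks a peer with a full change to send,
// false marks a peer that has been removed.
changed := map[types.NodeID]bool{
	peerA.ID: true,
	peerB.ID: false,
}
delta, err := mapper.PeerChangedResponse(mapReq, node, changed, nil, pol)
if err != nil {
	return err
}

// full and delta are marshalled (and possibly compressed) MapResponse
// payloads, ready to be written to the client's HTTP stream.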


@ -331,13 +331,10 @@ func Test_fullMapResponse(t *testing.T) {
node *types.Node
peers types.Nodes
baseDomain string
dnsConfig *tailcfg.DNSConfig
derpMap *tailcfg.DERPMap
logtail bool
randomClientPort bool
want *tailcfg.MapResponse
wantErr bool
derpMap *tailcfg.DERPMap
cfg *types.Config
want *tailcfg.MapResponse
wantErr bool
}{
// {
// name: "empty-node",
@ -349,15 +346,17 @@ func Test_fullMapResponse(t *testing.T) {
// wantErr: true,
// },
{
name: "no-pol-no-peers-map-response",
pol: &policy.ACLPolicy{},
node: mini,
peers: types.Nodes{},
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
logtail: false,
randomClientPort: false,
name: "no-pol-no-peers-map-response",
pol: &policy.ACLPolicy{},
node: mini,
peers: types.Nodes{},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
DNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
Node: tailMini,
KeepAlive: false,
@ -383,11 +382,13 @@ func Test_fullMapResponse(t *testing.T) {
peers: types.Nodes{
peer1,
},
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
logtail: false,
randomClientPort: false,
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
DNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
@ -424,11 +425,13 @@ func Test_fullMapResponse(t *testing.T) {
peer1,
peer2,
},
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
logtail: false,
randomClientPort: false,
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
DNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
@ -463,17 +466,15 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mappy := NewMapper(
tt.node,
tt.peers,
nil,
tt.cfg,
tt.derpMap,
tt.baseDomain,
tt.dnsConfig,
tt.logtail,
tt.randomClientPort,
nil,
)
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
tt.pol,
0,
)


@ -3,12 +3,10 @@ package mapper
import (
"fmt"
"net/netip"
"strconv"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo"
"tailscale.com/tailcfg"
)
@ -17,9 +15,7 @@ func tailNodes(
nodes types.Nodes,
capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig,
baseDomain string,
randomClientPort bool,
cfg *types.Config,
) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes))
@ -28,9 +24,7 @@ func tailNodes(
node,
capVer,
pol,
dnsConfig,
baseDomain,
randomClientPort,
cfg,
)
if err != nil {
return nil, err
@ -48,9 +42,7 @@ func tailNode(
node *types.Node,
capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig,
baseDomain string,
randomClientPort bool,
cfg *types.Config,
) (*tailcfg.Node, error) {
addrs := node.IPAddresses.Prefixes()
@ -85,7 +77,7 @@ func tailNode(
keyExpiry = time.Time{}
}
hostname, err := node.GetFQDN(dnsConfig, baseDomain)
hostname, err := node.GetFQDN(cfg.DNSConfig, cfg.BaseDomain)
if err != nil {
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
}
@ -94,12 +86,10 @@ func tailNode(
tags = lo.Uniq(append(tags, node.ForcedTags...))
tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
strconv.FormatUint(node.ID, util.Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
Cap: capVer,
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: node.ID.StableID(),
Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID),
@ -133,7 +123,7 @@ func tailNode(
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
}
if randomClientPort {
if cfg.RandomizeClientPort {
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}
} else {
@ -143,7 +133,7 @@ func tailNode(
tailcfg.CapabilitySSH,
}
if randomClientPort {
if cfg.RandomizeClientPort {
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
}
}


@ -182,13 +182,16 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
BaseDomain: tt.baseDomain,
DNSConfig: tt.dnsConfig,
RandomizeClientPort: false,
}
got, err := tailNode(
tt.node,
0,
tt.pol,
tt.dnsConfig,
tt.baseDomain,
false,
cfg,
)
if (err != nil) != tt.wantErr {