improve testing of route failover logic

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-04-10 15:35:09 +02:00 committed by Juan Font
parent bf4fd078fc
commit 1704977e76
11 changed files with 518 additions and 143 deletions

View file

@ -4,11 +4,13 @@ import (
"errors"
"fmt"
"net/netip"
"sort"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/util/set"
)
var ErrRouteIsNotAvailable = errors.New("route is not available")
@ -402,11 +404,10 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
return sendUpdate, nil
}
// FailoverRouteIfAvailable takes a node and checks if the node's route
// currently have a functioning host that exposes the network.
// If it does not, it is failed over to another suitable route if there
// is one.
func FailoverRouteIfAvailable(
// FailoverNodeRoutesIfNeccessary takes a node and checks if the node's routes
// need to be failed over to another host.
// If needed, the failover will be attempted.
func FailoverNodeRoutesIfNeccessary(
tx *gorm.DB,
isConnected types.NodeConnectedMap,
node *types.Node,
@ -416,8 +417,12 @@ func FailoverRouteIfAvailable(
return nil, nil
}
var changedNodes []types.NodeID
log.Trace().Msgf("NODE ROUTES: %d", len(nodeRoutes))
changedNodes := make(set.Set[types.NodeID])
nodeRouteLoop:
for _, nodeRoute := range nodeRoutes {
log.Trace().Msgf("NODE ROUTE: %d", nodeRoute.ID)
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil {
return nil, fmt.Errorf("getting routes by prefix: %w", err)
@ -427,29 +432,37 @@ func FailoverRouteIfAvailable(
if route.IsPrimary {
// if we have a primary route, and the node is connected
// nothing needs to be done.
if isConnected[route.Node.ID] {
return nil, nil
if conn, ok := isConnected[route.Node.ID]; conn && ok {
continue nodeRouteLoop
}
// if not, we need to failover the route
failover := failoverRoute(isConnected, &route, routes)
if failover != nil {
failover.save(tx)
err := failover.save(tx)
if err != nil {
return nil, fmt.Errorf("saving failover routes: %w", err)
}
changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID)
changedNodes.Add(failover.old.Node.ID)
changedNodes.Add(failover.new.Node.ID)
continue nodeRouteLoop
}
}
}
}
chng := changedNodes.Slice()
sort.SliceStable(chng, func(i, j int) bool {
return chng[i] < chng[j]
})
if len(changedNodes) != 0 {
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
Message: "called from db.FailoverRouteIfAvailable",
ChangeNodes: chng,
Message: "called from db.FailoverNodeRoutesIfNeccessary",
}, nil
}

View file

@ -7,9 +7,9 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
@ -270,6 +270,370 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
}
// ipp parses s as a CIDR prefix and converts it to types.IPPrefix.
// It panics on malformed input (netip.MustParsePrefix), so it is only
// suitable for literals in test tables.
var ipp = func(s string) types.IPPrefix {
	parsed := netip.MustParsePrefix(s)

	return types.IPPrefix(parsed)
}
// n returns a types.Node value with only its ID set to nid.
var n = func(nid types.NodeID) types.Node {
	var node types.Node
	node.ID = nid

	return node
}
// np returns a pointer to a freshly allocated types.Node with only its
// ID set to nid.
var np = func(nid types.NodeID) *types.Node {
	return &types.Node{ID: nid}
}
// r builds a types.Route value for test tables.
//
// id is stored as the gorm model ID, nid identifies the node owning the
// route, prefix is the advertised prefix, and enabled/primary populate
// the Enabled and IsPrimary flags.
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
	var route types.Route

	route.Model = gorm.Model{ID: id}
	route.Node = n(nid)
	route.Prefix = prefix
	route.Enabled = enabled
	route.IsPrimary = primary

	return route
}
// rp is the pointer-returning variant of r: it builds the same
// types.Route and returns the address of a fresh copy.
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
	route := new(types.Route)
	*route = r(id, nid, prefix, enabled, primary)

	return route
}
// dbForTest creates a sqlite-backed HSDatabase for a single test.
//
// The database file is placed in a new temporary directory whose name is
// derived from testName. Any failure while creating the directory or the
// database aborts the calling test through t.Fatalf. The resulting path is
// logged so the file can be inspected after the run.
//
// NOTE(review): db is assigned with "=" rather than ":=", so it is
// presumably a package-level variable shared with other tests — confirm
// before turning it into a local.
func dbForTest(t *testing.T, testName string) *HSDatabase {
	t.Helper()

	dir, err := os.MkdirTemp("", testName)
	if err != nil {
		t.Fatalf("creating tempdir: %s", err)
	}

	dbPath := dir + "/headscale_test.db"
	cfg := types.DatabaseConfig{
		Type:   "sqlite3",
		Sqlite: types.SqliteConfig{Path: dbPath},
	}

	if db, err = NewHeadscaleDatabase(cfg, ""); err != nil {
		t.Fatalf("setting up database: %s", err)
	}

	t.Logf("database set up at: %s", dbPath)

	return db
}
// TestFailoverNodeRoutesIfNeccessary drives FailoverNodeRoutesIfNeccessary
// through multi-step connectivity scenarios.
//
// Every test case seeds a fresh database with tt.routes and then replays a
// sequence of steps. Step i calls the function for tt.nodes[i] with the
// connectivity snapshot tt.isConnected[i] and compares the returned
// StateUpdate with tt.want[i]; only ChangeNodes is compared, the Type and
// Message fields are ignored. The three slices must therefore have the same
// length.
func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
	// su builds the expected StateUpdate carrying only the changed node IDs.
	su := func(nids ...types.NodeID) *types.StateUpdate {
		return &types.StateUpdate{
			ChangeNodes: nids,
		}
	}

	tests := []struct {
		name        string
		nodes       types.Nodes
		routes      types.Routes
		isConnected []types.NodeConnectedMap
		want        []*types.StateUpdate
		wantErr     bool
	}{
		{
			name: "n1-down-n2-down-n1-up",
			nodes: types.Nodes{
				np(1),
				np(2),
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: true,
				},
				// n2 goes down
				{
					1: false,
					2: false,
				},
				// n1 comes up
				{
					1: true,
					2: false,
				},
			},
			want: []*types.StateUpdate{
				// route changes from 1 -> 2
				su(1, 2),
				// both down, no change
				nil,
				// route changes from 2 -> 1
				su(1, 2),
			},
		},
		{
			name: "n1-recon-n2-down-n1-recon-n2-up",
			nodes: types.Nodes{
				np(1),
				np(2),
				np(1),
				np(2),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 up recon = noop
				{
					1: true,
					2: true,
				},
				// n2 goes down
				{
					1: true,
					2: false,
				},
				// n1 up recon = noop
				{
					1: true,
					2: false,
				},
				// n2 comes back up; n1 is still the connected primary,
				// so no failover is expected.
				{
					1: true,
					2: true,
				},
			},
			want: []*types.StateUpdate{
				nil,
				nil,
				nil,
				nil,
			},
		},
		{
			// Subtest names must be unique, otherwise the testing package
			// silently disambiguates them with a numeric suffix.
			name: "n1-down-n1-up-n3-down-n3-up-n2-up-n1-down",
			nodes: types.Nodes{
				np(1),
				np(1),
				np(3),
				np(3),
				np(2),
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
				r(3, 3, ipp("10.0.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: false,
					3: true,
				},
				// n1 comes up
				{
					1: true,
					2: false,
					3: true,
				},
				// n3 goes down
				{
					1: true,
					2: false,
					3: false,
				},
				// n3 comes up
				{
					1: true,
					2: false,
					3: true,
				},
				// n2 comes up
				{
					1: true,
					2: true,
					3: true,
				},
				// n1 goes down
				{
					1: false,
					2: true,
					3: true,
				},
			},
			want: []*types.StateUpdate{
				su(1, 3), // n1 -> n3
				nil,
				su(1, 3), // n3 -> n1
				nil,
				nil,
				su(1, 2), // n1 -> n2
			},
		},
		{
			name: "n1-recon-n2-dis-n3-take",
			nodes: types.Nodes{
				np(1),
				np(3),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), false, false),
				r(3, 3, ipp("10.0.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: true,
					3: true,
				},
				// n3 goes down
				{
					1: false,
					2: true,
					3: false,
				},
			},
			want: []*types.StateUpdate{
				su(1, 3), // n1 -> n3
				nil,
			},
		},
		{
			name: "multi-n1-oneforeach-n2-n3",
			nodes: types.Nodes{
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(4, 1, ipp("10.1.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
				r(3, 3, ipp("10.1.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: true,
					3: true,
				},
			},
			want: []*types.StateUpdate{
				su(1, 2, 3), // n1 -> n2,n3
			},
		},
		{
			name: "multi-n1-onefor-n2-disabled-n3",
			nodes: types.Nodes{
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(4, 1, ipp("10.1.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
				r(3, 3, ipp("10.1.0.0/24"), false, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: true,
					3: true,
				},
			},
			want: []*types.StateUpdate{
				su(1, 2), // n1 -> n2, n3 is not enabled
			},
		},
		{
			name: "multi-n1-onefor-n2-offline-n3",
			nodes: types.Nodes{
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, true),
				r(4, 1, ipp("10.1.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, false),
				r(3, 3, ipp("10.1.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n1 goes down
				{
					1: false,
					2: true,
					3: false,
				},
			},
			want: []*types.StateUpdate{
				su(1, 2), // n1 -> n2, n3 is offline
			},
		},
		{
			name: "multi-n2-back-to-multi-n1",
			nodes: types.Nodes{
				np(1),
			},
			routes: types.Routes{
				r(1, 1, ipp("10.0.0.0/24"), true, false),
				r(4, 1, ipp("10.1.0.0/24"), true, true),
				r(2, 2, ipp("10.0.0.0/24"), true, true),
				r(3, 3, ipp("10.1.0.0/24"), true, false),
			},
			isConnected: []types.NodeConnectedMap{
				// n2 goes down while n1 is up
				{
					1: true,
					2: false,
					3: true,
				},
			},
			want: []*types.StateUpdate{
				su(1, 2), // n2 -> n1
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// All three per-step slices are indexed with the same step
			// counter below, so any length mismatch is a broken test case.
			if len(tt.isConnected) != len(tt.want) || len(tt.want) != len(tt.nodes) {
				t.Fatalf("nodes (%d), isConnected updates (%d), wants (%d) must be equal", len(tt.nodes), len(tt.isConnected), len(tt.want))
			}

			db := dbForTest(t, tt.name)

			for _, route := range tt.routes {
				if err := db.DB.Save(&route).Error; err != nil {
					t.Fatalf("failed to create route: %s", err)
				}
			}

			for step := range len(tt.isConnected) {
				node := tt.nodes[step]
				isConnected := tt.isConnected[step]
				want := tt.want[step]

				got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
					return FailoverNodeRoutesIfNeccessary(tx, isConnected, node)
				})

				if (err != nil) != tt.wantErr {
					t.Errorf("FailoverNodeRoutesIfNeccessary() step %d error = %v, wantErr %v", step, err, tt.wantErr)

					return
				}

				// Type and Message are implementation details of the update;
				// only the set of changed nodes is asserted.
				if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(types.StateUpdate{}, "Type", "Message")); diff != "" {
					t.Errorf("FailoverNodeRoutesIfNeccessary() step %d unexpected result (-want +got):\n%s", step, diff)
				}
			}
		})
	}
}
func TestFailoverRouteTx(t *testing.T) {
tests := []struct {
@ -637,19 +1001,7 @@ func TestFailoverRouteTx(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err)
db, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
"",
)
assert.NoError(t, err)
db := dbForTest(t, tt.name)
for _, route := range tt.routes {
if err := db.DB.Save(&route).Error; err != nil {