improve testing of route failover logic
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
bf4fd078fc
commit
1704977e76
11 changed files with 518 additions and 143 deletions
|
@ -4,11 +4,13 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
@ -402,11 +404,10 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
|||
return sendUpdate, nil
|
||||
}
|
||||
|
||||
// FailoverRouteIfAvailable takes a node and checks if the node's route
|
||||
// currently have a functioning host that exposes the network.
|
||||
// If it does not, it is failed over to another suitable route if there
|
||||
// is one.
|
||||
func FailoverRouteIfAvailable(
|
||||
// FailoverNodeRoutesIfNeccessary takes a node and checks if the node's routes
|
||||
// need to be failed over to another host.
|
||||
// If needed, the failover will be attempted.
|
||||
func FailoverNodeRoutesIfNeccessary(
|
||||
tx *gorm.DB,
|
||||
isConnected types.NodeConnectedMap,
|
||||
node *types.Node,
|
||||
|
@ -416,8 +417,12 @@ func FailoverRouteIfAvailable(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
var changedNodes []types.NodeID
|
||||
log.Trace().Msgf("NODE ROUTES: %d", len(nodeRoutes))
|
||||
changedNodes := make(set.Set[types.NodeID])
|
||||
|
||||
nodeRouteLoop:
|
||||
for _, nodeRoute := range nodeRoutes {
|
||||
log.Trace().Msgf("NODE ROUTE: %d", nodeRoute.ID)
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
||||
|
@ -427,29 +432,37 @@ func FailoverRouteIfAvailable(
|
|||
if route.IsPrimary {
|
||||
// if we have a primary route, and the node is connected
|
||||
// nothing needs to be done.
|
||||
if isConnected[route.Node.ID] {
|
||||
return nil, nil
|
||||
if conn, ok := isConnected[route.Node.ID]; conn && ok {
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
|
||||
// if not, we need to failover the route
|
||||
failover := failoverRoute(isConnected, &route, routes)
|
||||
if failover != nil {
|
||||
failover.save(tx)
|
||||
err := failover.save(tx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving failover routes: %w", err)
|
||||
}
|
||||
|
||||
changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID)
|
||||
changedNodes.Add(failover.old.Node.ID)
|
||||
changedNodes.Add(failover.new.Node.ID)
|
||||
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chng := changedNodes.Slice()
|
||||
sort.SliceStable(chng, func(i, j int) bool {
|
||||
return chng[i] < chng[j]
|
||||
})
|
||||
|
||||
if len(changedNodes) != 0 {
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changedNodes,
|
||||
Message: "called from db.FailoverRouteIfAvailable",
|
||||
ChangeNodes: chng,
|
||||
Message: "called from db.FailoverNodeRoutesIfNeccessary",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -7,9 +7,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -270,6 +270,370 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
}
|
||||
|
||||
// ipp parses s as a CIDR prefix and converts the result to a
// types.IPPrefix. It panics on malformed input (netip.MustParsePrefix),
// which keeps fixture tables terse.
var ipp = func(s string) types.IPPrefix {
	parsed := netip.MustParsePrefix(s)

	return types.IPPrefix(parsed)
}
|
||||
var n = func(nid types.NodeID) types.Node {
|
||||
return types.Node{ID: nid}
|
||||
}
|
||||
var np = func(nid types.NodeID) *types.Node {
|
||||
no := n(nid)
|
||||
return &no
|
||||
}
|
||||
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
||||
return types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: id,
|
||||
},
|
||||
Node: n(nid),
|
||||
Prefix: prefix,
|
||||
Enabled: enabled,
|
||||
IsPrimary: primary,
|
||||
}
|
||||
}
|
||||
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
||||
ro := r(id, nid, prefix, enabled, primary)
|
||||
return &ro
|
||||
}
|
||||
|
||||
func dbForTest(t *testing.T, testName string) *HSDatabase {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", testName)
|
||||
if err != nil {
|
||||
t.Fatalf("creating tempdir: %s", err)
|
||||
}
|
||||
|
||||
dbPath := tmpDir + "/headscale_test.db"
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: dbPath,
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("setting up database: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("database set up at: %s", dbPath)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||
su := func(nids ...types.NodeID) *types.StateUpdate {
|
||||
return &types.StateUpdate{
|
||||
ChangeNodes: nids,
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
nodes types.Nodes
|
||||
routes types.Routes
|
||||
isConnected []types.NodeConnectedMap
|
||||
want []*types.StateUpdate
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "n1-down-n2-down-n1-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(2),
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
},
|
||||
// n2 goes down
|
||||
{
|
||||
1: false,
|
||||
2: false,
|
||||
},
|
||||
// n1 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
// route changes from 1 -> 2
|
||||
su(1, 2),
|
||||
// both down, no change
|
||||
nil,
|
||||
// route changes from 2 -> 1
|
||||
su(1, 2),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-down-n1-recon-n2-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(2),
|
||||
np(1),
|
||||
np(2),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 up recon = noop
|
||||
{
|
||||
1: true,
|
||||
2: true,
|
||||
},
|
||||
// n2 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
// n1 up recon = noop
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
// n2 comes back up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-down-n1-recon-n2-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(1),
|
||||
np(3),
|
||||
np(3),
|
||||
np(2),
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n1 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n3 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: false,
|
||||
},
|
||||
// n3 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n2 comes up
|
||||
{
|
||||
1: true,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 3), // n1 -> n3
|
||||
nil,
|
||||
su(1, 3), // n3 -> n1
|
||||
nil,
|
||||
nil,
|
||||
su(1, 2), // n1 -> n2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-dis-n3-take",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(3),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), false, false),
|
||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
// n3 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 3), // n1 -> n3
|
||||
nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-oneforeach-n2-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2, 3), // n1 -> n2,n3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-onefor-n2-disabled-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), false, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n1 -> n2, n3 is not enabled
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-onefor-n2-offline-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n1 -> n2, n3 is offline
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n2-back-to-multi-n1",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n2 -> n1
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if (len(tt.isConnected) != len(tt.want)) && len(tt.want) != len(tt.nodes) {
|
||||
t.Fatalf("nodes (%d), isConnected updates (%d), wants (%d) must be equal", len(tt.nodes), len(tt.isConnected), len(tt.want))
|
||||
}
|
||||
|
||||
db := dbForTest(t, tt.name)
|
||||
|
||||
for _, route := range tt.routes {
|
||||
if err := db.DB.Save(&route).Error; err != nil {
|
||||
t.Fatalf("failed to create route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for step := range len(tt.isConnected) {
|
||||
node := tt.nodes[step]
|
||||
isConnected := tt.isConnected[step]
|
||||
want := tt.want[step]
|
||||
|
||||
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return FailoverNodeRoutesIfNeccessary(tx, isConnected, node)
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(types.StateUpdate{}, "Type", "Message")); diff != "" {
|
||||
t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverRouteTx(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
@ -637,19 +1001,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "failover-db-test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
db := dbForTest(t, tt.name)
|
||||
|
||||
for _, route := range tt.routes {
|
||||
if err := db.DB.Save(&route).Error; err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue