Replace database locks with transactions (#1701)

This commits removes the locks used to guard data integrity for the
database and replaces them with Transactions, turns out that SQL had
a way to deal with this all along.

This reduces the complexity we had with multiple locks that might stack
or recurse (database, nofitifer, mapper). All notifications and state
updates are now triggered _after_ a database change.


Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-02-08 17:28:19 +01:00 committed by GitHub
parent cbf57e27a7
commit 83769ba715
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1496 additions and 1128 deletions

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
@ -24,7 +23,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_get_route_node")
_, err = db.getNode("test", "test_get_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24")
@ -42,7 +41,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -52,10 +51,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1)
err = db.enableRoutes(&node, "192.168.0.0/24")
// TODO(kradalby): check state update
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
@ -66,7 +66,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -91,7 +91,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -106,10 +106,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = db.enableRoutes(&node, "192.168.0.0/24")
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(&node)
@ -117,14 +117,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = db.enableRoutes(&node, "150.0.10.0/25")
_, err = db.enableRoutes(&node, "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
@ -139,7 +139,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -163,16 +163,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String())
_, err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, route2.String())
_, err = db.enableRoutes(&node1, route2.String())
c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{
@ -186,13 +186,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2,
}
db.db.Save(&node2)
db.DB.Save(&node2)
sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String())
_, err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -219,7 +219,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
@ -246,22 +246,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String())
_, err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String())
_, err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil)
routes, err := db.GetNodeRoutes(&node1)
c.Assert(err, check.IsNil)
err = db.DeleteRoute(uint64(routes[0].ID))
// TODO(kradalby): check stateupdate
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -269,17 +270,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(len(enabledRoutes1), check.Equals, 1)
}
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
func TestFailoverRoute(t *testing.T) {
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Count/verify updates
var sink chan types.StateUpdate
go func() {
for range sink {
}
}()
machineKeys := []key.MachinePublic{
key.NewMachine().Public(),
key.NewMachine().Public(),
@ -291,6 +284,7 @@ func TestFailoverRoute(t *testing.T) {
name string
failingRoute types.Route
routes types.Routes
isConnected map[key.MachinePublic]bool
want []key.MachinePublic
wantErr bool
}{
@ -397,6 +391,10 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
@ -491,6 +489,11 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[1]: true,
machineKeys[2]: true,
},
want: []key.MachinePublic{
machineKeys[1],
machineKeys[0],
@ -535,6 +538,10 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[3]: false,
},
want: nil,
wantErr: false,
},
@ -587,6 +594,11 @@ func TestFailoverRoute(t *testing.T) {
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
machineKeys[3]: false,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
@ -641,13 +653,10 @@ func TestFailoverRoute(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err)
notif := notifier.NewNotifier()
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
notif,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
@ -655,23 +664,15 @@ func TestFailoverRoute(t *testing.T) {
)
assert.NoError(t, err)
// Pretend that all the nodes are connected to control
for idx, key := range machineKeys {
// Pretend one node is offline
if idx == 3 {
continue
}
notif.AddNode(key, sink)
}
for _, route := range tt.routes {
if err := db.db.Save(&route).Error; err != nil {
if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err)
}
}
got, err := db.failoverRoute(&tt.failingRoute)
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
})
if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
@ -685,3 +686,231 @@ func TestFailoverRoute(t *testing.T) {
})
}
}
// func TestDisableRouteFailover(t *testing.T) {
// machineKeys := []key.MachinePublic{
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// }
// tests := []struct {
// name string
// nodes types.Nodes
// routeID uint64
// isConnected map[key.MachinePublic]bool
// wantMachineKey key.MachinePublic
// wantErr string
// }{
// {
// name: "single-route",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// Node: types.Node{
// MachineKey: machineKeys[0],
// },
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[0],
// },
// {
// name: "failover-simple",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "no-failover-offline",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: false,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "failover-to-online",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: true,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
// assert.NoError(t, err)
// // bootstrap db
// datab.DB.Transaction(func(tx *gorm.DB) error {
// for _, node := range tt.nodes {
// err := tx.Save(node).Error
// if err != nil {
// return err
// }
// _, err = SaveNodeRoutes(tx, node)
// if err != nil {
// return err
// }
// }
// return nil
// })
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
// return DisableRoute(tx, tt.routeID, tt.isConnected)
// })
// // if (err.Error() != "") != tt.wantErr {
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
// // return
// // }
// if len(got.ChangeNodes) != 1 {
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
// }
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }