Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
cbf57e27a7
commit
83769ba715
32 changed files with 1496 additions and 1128 deletions
|
@ -8,11 +8,13 @@ import (
|
|||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
|||
ctx context.Context,
|
||||
request *v1.ExpirePreAuthKeyRequest,
|
||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = api.h.db.ExpirePreAuthKey(preAuthKey)
|
||||
return db.ExpirePreAuthKey(tx, preAuthKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
node, err := api.h.db.RegisterNodeFromAuthCallback(
|
||||
api.h.registrationCache,
|
||||
mkey,
|
||||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
)
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return db.RegisterNodeFromAuthCallback(
|
||||
tx,
|
||||
api.h.registrationCache,
|
||||
mkey,
|
||||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
api.h.cfg.IPPrefixes,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.RegisterNode",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
|
@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
|
|||
ctx context.Context,
|
||||
request *v1.SetTagsRequest,
|
||||
) (*v1.SetTagsResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tag := range request.GetTags() {
|
||||
err := validateTag(tag)
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = api.h.db.SetTags(node, request.GetTags())
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.Internal, err.Error())
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.SetTags",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
|
@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||
|
||||
err = api.h.db.DeleteNode(
|
||||
node,
|
||||
api.h.nodeNotifier.ConnectedMap(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
||||
}
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
|
|||
ctx context.Context,
|
||||
request *v1.ExpireNodeRequest,
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
now := time.Now()
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
db.NodeSetExpiry(
|
||||
tx,
|
||||
request.GetNodeId(),
|
||||
now,
|
||||
)
|
||||
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
selfUpdate := types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if selfUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByMachineKey(
|
||||
ctx,
|
||||
selfUpdate,
|
||||
node.MachineKey)
|
||||
}
|
||||
|
||||
api.h.db.NodeSetExpiry(
|
||||
node,
|
||||
now,
|
||||
)
|
||||
stateUpdate := types.StateUpdateExpire(node.ID, now)
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
|
@ -306,17 +365,30 @@ func (api headscaleV1APIServer) RenameNode(
|
|||
ctx context.Context,
|
||||
request *v1.RenameNodeRequest,
|
||||
) (*v1.RenameNodeResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.RenameNode(
|
||||
tx,
|
||||
request.GetNodeId(),
|
||||
request.GetNewName(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = api.h.db.RenameNode(
|
||||
node,
|
||||
request.GetNewName(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.RenameNode",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
|
@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
ctx context.Context,
|
||||
request *v1.ListNodesRequest,
|
||||
) (*v1.ListNodesResponse, error) {
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
nodes, err := api.h.db.ListNodesByUser(request.GetUser())
|
||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return db.ListNodesByUser(rx, request.GetUser())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
|
||||
resp.Online = isConnected[node.MachineKey]
|
||||
|
||||
response[index] = resp
|
||||
}
|
||||
|
@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
|
||||
resp.Online = isConnected[node.MachineKey]
|
||||
|
||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
||||
&node,
|
||||
node,
|
||||
)
|
||||
resp.InvalidTags = invalidTags
|
||||
resp.ValidTags = validTags
|
||||
|
@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
|
|||
ctx context.Context,
|
||||
request *v1.GetRoutesRequest,
|
||||
) (*v1.GetRoutesResponse, error) {
|
||||
routes, err := api.h.db.GetRoutes()
|
||||
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return db.GetRoutes(rx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
|
|||
ctx context.Context,
|
||||
request *v1.EnableRouteRequest,
|
||||
) (*v1.EnableRouteResponse, error) {
|
||||
err := api.h.db.EnableRoute(request.GetRouteId())
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.EnableRoute(tx, request.GetRouteId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(
|
||||
ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.EnableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
|
|||
ctx context.Context,
|
||||
request *v1.DisableRouteRequest,
|
||||
) (*v1.DisableRouteResponse, error) {
|
||||
err := api.h.db.DisableRoute(request.GetRouteId())
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.DisableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
|
|||
ctx context.Context,
|
||||
request *v1.DeleteRouteRequest,
|
||||
) (*v1.DeleteRouteResponse, error) {
|
||||
err := api.h.db.DeleteRoute(request.GetRouteId())
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.DeleteRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue