use gorm serialiser instead of custom hooks (#2156)

* add sqlite to debug/test image

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* test using gorm serialiser instead of custom hooks

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-02 11:41:58 +02:00 committed by GitHub
parent 3964dec1c6
commit bc9e83b52e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 240 additions and 351 deletions

View file

@ -20,9 +20,14 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"tailscale.com/util/set"
)
func init() {
schema.RegisterSerializer("text", TextSerialiser{})
}
var errDatabaseNotSupported = errors.New("database type not supported")
// KV is a key-value store in a psql table. For future use...
@ -33,7 +38,8 @@ type KV struct {
}
type HSDatabase struct {
DB *gorm.DB
DB *gorm.DB
cfg *types.DatabaseConfig
baseDomain string
}
@ -191,7 +197,7 @@ func NewHeadscaleDatabase(
type NodeAux struct {
ID uint64
EnabledRoutes types.IPPrefixes
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
}
nodesAux := []NodeAux{}
@ -214,7 +220,7 @@ func NewHeadscaleDatabase(
}
err = tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix).
First(&types.Route{}).
Error
if err == nil {
@ -229,7 +235,7 @@ func NewHeadscaleDatabase(
NodeID: node.ID,
Advertised: true,
Enabled: true,
Prefix: types.IPPrefix(prefix),
Prefix: prefix,
}
if err := tx.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
@ -476,7 +482,8 @@ func NewHeadscaleDatabase(
}
db := HSDatabase{
DB: dbConn,
DB: dbConn,
cfg: &cfg,
baseDomain: baseDomain,
}
@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error {
return err
}
if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
db.Exec("VACUUM")
}
return db.Close()
}

View file

@ -13,13 +13,14 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func TestMigrations(t *testing.T) {
ipp := func(p string) types.IPPrefix {
return types.IPPrefix(netip.MustParsePrefix(p))
ipp := func(p string) netip.Prefix {
return netip.MustParsePrefix(p)
}
r := func(id uint64, p string, a, e, i bool) types.Route {
return types.Route{
@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) {
}
},
},
{
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
assert.NoError(t, err)
for _, node := range nodes {
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
assert.Contains(t, node.MachineKey.String(), "mkey:")
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
assert.Contains(t, node.NodeKey.String(), "nodekey:")
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
assert.Contains(t, node.DiscoKey.String(), "discokey:")
assert.NotNil(t, node.IPv4)
assert.NotNil(t, node.IPv4)
assert.Len(t, node.Endpoints, 1)
assert.NotNil(t, node.Hostinfo)
assert.NotNil(t, node.MachineKey)
}
},
},
}
for _, tt := range tests {

View file

@ -1,7 +1,6 @@
package db
import (
"database/sql"
"fmt"
"net/netip"
"strings"
@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) {
v4 := fmt.Sprintf("100.64.0.%d", i)
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
return &types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: v4,
},
IPv4: nap(v4),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: v6,
},
IPv6: nap(v6),
}
}
@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) {
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) {
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) {
want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
},
},
@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) {
want: types.Nodes{
&types.Node{
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) {
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
"ID",
"MachineKeyDatabaseField",
"NodeKeyDatabaseField",
"DiscoKeyDatabaseField",
"User",
"UserID",
"Endpoints",
"HostinfoDatabaseField",
"Hostinfo",
"Routes",
"CreatedAt",

View file

@ -1,6 +1,7 @@
package db
import (
"encoding/json"
"errors"
"fmt"
"net/netip"
@ -207,21 +208,26 @@ func SetTags(
) error {
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
}
return nil
}
var newTags types.StringList
var newTags []string
for _, tag := range tags {
if !slices.Contains(newTags, tag) {
newTags = append(newTags, tag)
}
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
b, err := json.Marshal(newTags)
if err != nil {
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}
@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB,
for _, prefix := range newRoutes {
route := types.Route{}
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
First(&route).Error
if err == nil {
route.Enabled = true

View file

@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
adminNode, err := db.GetNodeByID(1)
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
c.Assert(adminNode.IPv4, check.NotNil)
c.Assert(adminNode.IPv6, check.IsNil)
c.Assert(err, check.IsNil)
testNode, err := db.GetNodeByID(2)
@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
adminPeers, err := db.ListPeers(adminNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(adminPeers), check.Equals, 9)
testPeers, err := db.ListPeers(testNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
c.Assert(err, check.IsNil)
@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
c.Log(peersOfAdminNode)
c.Log(peersOfTestNode)
c.Assert(len(peersOfTestNode), check.Equals, 9)
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
c.Log(peersOfAdminNode)
c.Assert(len(peersOfAdminNode), check.Equals, 9)
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
// assign duplicate tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
@ -357,7 +361,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(
node.ForcedTags,
check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
[]string{"tag:bar", "tag:test", "tag:unknown"},
)
// test removing tags
@ -365,7 +369,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
}
func TestHeadscale_generateGivenName(t *testing.T) {

View file

@ -77,7 +77,7 @@ func CreatePreAuthKey(
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
Tags: types.StringList(aclTags),
Tags: aclTags,
}
if err := tx.Save(&key).Error; err != nil {

View file

@ -49,7 +49,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
err := tx.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
Where("prefix = ?", pref.String()).
Find(&routes).Error
if err != nil {
return nil, err
@ -286,7 +286,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64
tx.Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.Prefix.String(),
route.NodeID,
true, true).Count(&count)
@ -297,7 +297,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route
err := tx.
Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
First(&route).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
@ -392,7 +392,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
if !exists {
route := types.Route{
NodeID: node.ID.Uint64(),
Prefix: types.IPPrefix(prefix),
Prefix: prefix,
Advertised: true,
Enabled: false,
}

View file

@ -290,7 +290,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
}
var (
ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
mkNode = func(nid types.NodeID) types.Node {
return types.Node{ID: nid}
}
@ -301,7 +301,7 @@ var np = func(nid types.NodeID) *types.Node {
return &no
}
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
return types.Route{
Model: gorm.Model{
ID: id,
@ -313,7 +313,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary
}
}
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
@ -1069,7 +1069,7 @@ func TestFailoverRouteTx(t *testing.T) {
}
func TestFailoverRoute(t *testing.T) {
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
return types.Route{
Model: gorm.Model{
ID: id,
@ -1082,7 +1082,7 @@ func TestFailoverRoute(t *testing.T) {
IsPrimary: primary,
}
}
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
@ -1205,13 +1205,6 @@ func TestFailoverRoute(t *testing.T) {
},
}
cmps := append(
util.Comparers,
cmp.Comparer(func(x, y types.IPPrefix) bool {
return netip.Prefix(x) == netip.Prefix(y)
}),
)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
@ -1235,7 +1228,7 @@ func TestFailoverRoute(t *testing.T) {
"old": gotf.old,
}
if diff := cmp.Diff(want, got, cmps...); diff != "" {
if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
}
}

Binary file not shown.

View file

@ -0,0 +1,99 @@
package db
import (
"context"
"encoding"
"fmt"
"reflect"
"gorm.io/gorm/schema"
)
// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
func isTextUnmarshaler(rv reflect.Value) bool {
return rv.Type().Implements(textUnmarshalerType)
}
func maybeInstantiatePtr(rv reflect.Value) {
if rv.Kind() == reflect.Ptr && rv.IsNil() {
np := reflect.New(rv.Type().Elem())
rv.Set(np)
}
}
func decodingError(name string, err error) error {
return fmt.Errorf("error decoding to %s: %w", name, err)
}
// TextSerialiser implements the Serialiser interface for fields that
// have a type that implements encoding.TextUnmarshaler.
type TextSerialiser struct{}
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
// If the field is a pointer, we need to dereference it to get the actual type
// so we do not end with a second pointer.
if fieldValue.Elem().Kind() == reflect.Ptr {
fieldValue = fieldValue.Elem()
}
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
}
if isTextUnmarshaler(fieldValue) {
maybeInstantiatePtr(fieldValue)
f := fieldValue.MethodByName("UnmarshalText")
args := []reflect.Value{reflect.ValueOf(bytes)}
ret := f.Call(args)
if !ret[0].IsNil() {
return decodingError(field.Name, ret[0].Interface().(error))
}
// If the underlying field is to a pointer type, we need to
// assign the value as a pointer to it.
// If it is not a pointer, we need to assign the value to the
// field.
dstField := field.ReflectValueOf(ctx, dst)
if dstField.Kind() == reflect.Ptr {
dstField.Set(fieldValue)
} else {
dstField.Set(fieldValue.Elem())
}
return nil
} else {
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
}
}
return
}
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
switch v := fieldValue.(type) {
case encoding.TextMarshaler:
// If the value is nil, we return nil, however, go nil values are not
// always comparable, particularly when reflection is involved:
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
return nil, nil
}
b, err := v.MarshalText()
if err != nil {
return nil, err
}
return string(b), nil
default:
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
}
}