use gorm serialiser instead of custom hooks (#2156)
* add sqlite to debug/test image Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * test using gorm serialiser instead of custom hooks Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
3964dec1c6
commit
bc9e83b52e
21 changed files with 240 additions and 351 deletions
|
@ -20,9 +20,14 @@ import (
|
|||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
func init() {
|
||||
schema.RegisterSerializer("text", TextSerialiser{})
|
||||
}
|
||||
|
||||
var errDatabaseNotSupported = errors.New("database type not supported")
|
||||
|
||||
// KV is a key-value store in a psql table. For future use...
|
||||
|
@ -33,7 +38,8 @@ type KV struct {
|
|||
}
|
||||
|
||||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
DB *gorm.DB
|
||||
cfg *types.DatabaseConfig
|
||||
|
||||
baseDomain string
|
||||
}
|
||||
|
@ -191,7 +197,7 @@ func NewHeadscaleDatabase(
|
|||
|
||||
type NodeAux struct {
|
||||
ID uint64
|
||||
EnabledRoutes types.IPPrefixes
|
||||
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
nodesAux := []NodeAux{}
|
||||
|
@ -214,7 +220,7 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
err = tx.Preload("Node").
|
||||
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
|
||||
Where("node_id = ? AND prefix = ?", node.ID, prefix).
|
||||
First(&types.Route{}).
|
||||
Error
|
||||
if err == nil {
|
||||
|
@ -229,7 +235,7 @@ func NewHeadscaleDatabase(
|
|||
NodeID: node.ID,
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
Prefix: prefix,
|
||||
}
|
||||
if err := tx.Create(&route).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error creating route")
|
||||
|
@ -476,7 +482,8 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
db := HSDatabase{
|
||||
DB: dbConn,
|
||||
DB: dbConn,
|
||||
cfg: &cfg,
|
||||
|
||||
baseDomain: baseDomain,
|
||||
}
|
||||
|
@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
|
||||
db.Exec("VACUUM")
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -13,13 +13,14 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestMigrations(t *testing.T) {
|
||||
ipp := func(p string) types.IPPrefix {
|
||||
return types.IPPrefix(netip.MustParsePrefix(p))
|
||||
ipp := func(p string) netip.Prefix {
|
||||
return netip.MustParsePrefix(p)
|
||||
}
|
||||
r := func(id uint64, p string, a, e, i bool) types.Route {
|
||||
return types.Route{
|
||||
|
@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
|
|||
r(31, "::/0", true, false, false),
|
||||
r(32, "192.168.0.24/32", true, true, true),
|
||||
}
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
|
||||
return x == y
|
||||
})); diff != "" {
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
|
@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
|
|||
r(13, "::/0", true, true, false),
|
||||
r(13, "10.18.80.2/32", true, true, true),
|
||||
}
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
|
||||
return x == y
|
||||
})); diff != "" {
|
||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
|
@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) {
|
|||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
|
||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, node := range nodes {
|
||||
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
|
||||
assert.Contains(t, node.MachineKey.String(), "mkey:")
|
||||
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
|
||||
assert.Contains(t, node.NodeKey.String(), "nodekey:")
|
||||
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
|
||||
assert.Contains(t, node.DiscoKey.String(), "discokey:")
|
||||
assert.NotNil(t, node.IPv4)
|
||||
assert.NotNil(t, node.IPv4)
|
||||
assert.Len(t, node.Endpoints, 1)
|
||||
assert.NotNil(t, node.Hostinfo)
|
||||
assert.NotNil(t, node.MachineKey)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||
return &types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: v4,
|
||||
},
|
||||
IPv4: nap(v4),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: v6,
|
||||
},
|
||||
IPv6: nap(v6),
|
||||
}
|
||||
}
|
||||
|
@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
|
@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
|
@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "100.64.0.1",
|
||||
},
|
||||
IPv4: nap("100.64.0.1"),
|
||||
},
|
||||
},
|
||||
|
@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
|
||||
want: types.Nodes{
|
||||
&types.Node{
|
||||
IPv6DatabaseField: sql.NullString{
|
||||
Valid: true,
|
||||
String: "fd7a:115c:a1e0::1",
|
||||
},
|
||||
IPv6: nap("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
},
|
||||
|
@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
|
||||
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
|
||||
"ID",
|
||||
"MachineKeyDatabaseField",
|
||||
"NodeKeyDatabaseField",
|
||||
"DiscoKeyDatabaseField",
|
||||
"User",
|
||||
"UserID",
|
||||
"Endpoints",
|
||||
"HostinfoDatabaseField",
|
||||
"Hostinfo",
|
||||
"Routes",
|
||||
"CreatedAt",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
@ -207,21 +208,26 @@ func SetTags(
|
|||
) error {
|
||||
if len(tags) == 0 {
|
||||
// if no tags are provided, we remove all forced tags
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
|
||||
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var newTags types.StringList
|
||||
var newTags []string
|
||||
for _, tag := range tags {
|
||||
if !slices.Contains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
|
||||
b, err := json.Marshal(newTags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for node in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB,
|
|||
for _, prefix := range newRoutes {
|
||||
route := types.Route{}
|
||||
err := tx.Preload("Node").
|
||||
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
|
||||
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
|
||||
First(&route).Error
|
||||
if err == nil {
|
||||
route.Enabled = true
|
||||
|
|
|
@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
|
||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
|
@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
|
||||
adminNode, err := db.GetNodeByID(1)
|
||||
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
|
||||
c.Assert(adminNode.IPv4, check.NotNil)
|
||||
c.Assert(adminNode.IPv6, check.IsNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testNode, err := db.GetNodeByID(2)
|
||||
|
@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
|
||||
adminPeers, err := db.ListPeers(adminNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(adminPeers), check.Equals, 9)
|
||||
|
||||
testPeers, err := db.ListPeers(testNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(testPeers), check.Equals, 9)
|
||||
|
||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
|
||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
|
||||
|
||||
c.Log(peersOfAdminNode)
|
||||
c.Log(peersOfTestNode)
|
||||
|
||||
c.Assert(len(peersOfTestNode), check.Equals, 9)
|
||||
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
|
||||
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
|
||||
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
|
||||
|
||||
c.Log(peersOfAdminNode)
|
||||
c.Assert(len(peersOfAdminNode), check.Equals, 9)
|
||||
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
|
||||
|
@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
node, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
||||
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
||||
|
||||
// assign duplicate tags, expect no errors but no doubles in DB
|
||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||
|
@ -357,7 +361,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
c.Assert(
|
||||
node.ForcedTags,
|
||||
check.DeepEquals,
|
||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
[]string{"tag:bar", "tag:test", "tag:unknown"},
|
||||
)
|
||||
|
||||
// test removing tags
|
||||
|
@ -365,7 +369,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
node, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
|
||||
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
||||
}
|
||||
|
||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
|
|
|
@ -77,7 +77,7 @@ func CreatePreAuthKey(
|
|||
Ephemeral: ephemeral,
|
||||
CreatedAt: &now,
|
||||
Expiration: expiration,
|
||||
Tags: types.StringList(aclTags),
|
||||
Tags: aclTags,
|
||||
}
|
||||
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
|
|
|
@ -49,7 +49,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
|
|||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("prefix = ?", types.IPPrefix(pref)).
|
||||
Where("prefix = ?", pref.String()).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -286,7 +286,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
|
|||
var count int64
|
||||
tx.Model(&types.Route{}).
|
||||
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.Prefix.String(),
|
||||
route.NodeID,
|
||||
true, true).Count(&count)
|
||||
|
||||
|
@ -297,7 +297,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
|
|||
var route types.Route
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
|
||||
First(&route).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
|
@ -392,7 +392,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
|||
if !exists {
|
||||
route := types.Route{
|
||||
NodeID: node.ID.Uint64(),
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
Prefix: prefix,
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
|
|
|
@ -290,7 +290,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
}
|
||||
|
||||
var (
|
||||
ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||
ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
|
||||
mkNode = func(nid types.NodeID) types.Node {
|
||||
return types.Node{ID: nid}
|
||||
}
|
||||
|
@ -301,7 +301,7 @@ var np = func(nid types.NodeID) *types.Node {
|
|||
return &no
|
||||
}
|
||||
|
||||
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
||||
var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
|
||||
return types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: id,
|
||||
|
@ -313,7 +313,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary
|
|||
}
|
||||
}
|
||||
|
||||
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
||||
var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
|
||||
ro := r(id, nid, prefix, enabled, primary)
|
||||
return &ro
|
||||
}
|
||||
|
@ -1069,7 +1069,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFailoverRoute(t *testing.T) {
|
||||
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
||||
r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
|
||||
return types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: id,
|
||||
|
@ -1082,7 +1082,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
IsPrimary: primary,
|
||||
}
|
||||
}
|
||||
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
||||
rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
|
||||
ro := r(id, nid, prefix, enabled, primary)
|
||||
return &ro
|
||||
}
|
||||
|
@ -1205,13 +1205,6 @@ func TestFailoverRoute(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
cmps := append(
|
||||
util.Comparers,
|
||||
cmp.Comparer(func(x, y types.IPPrefix) bool {
|
||||
return netip.Prefix(x) == netip.Prefix(y)
|
||||
}),
|
||||
)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
|
||||
|
@ -1235,7 +1228,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
"old": gotf.old,
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got, cmps...); diff != "" {
|
||||
if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
|
||||
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
|
BIN
hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite
vendored
Normal file
BIN
hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite
vendored
Normal file
Binary file not shown.
99
hscontrol/db/text_serialiser.go
Normal file
99
hscontrol/db/text_serialiser.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
|
||||
func isTextUnmarshaler(rv reflect.Value) bool {
|
||||
return rv.Type().Implements(textUnmarshalerType)
|
||||
}
|
||||
|
||||
func maybeInstantiatePtr(rv reflect.Value) {
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
np := reflect.New(rv.Type().Elem())
|
||||
rv.Set(np)
|
||||
}
|
||||
}
|
||||
|
||||
func decodingError(name string, err error) error {
|
||||
return fmt.Errorf("error decoding to %s: %w", name, err)
|
||||
}
|
||||
|
||||
// TextSerialiser implements the Serialiser interface for fields that
|
||||
// have a type that implements encoding.TextUnmarshaler.
|
||||
type TextSerialiser struct{}
|
||||
|
||||
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
|
||||
// If the field is a pointer, we need to dereference it to get the actual type
|
||||
// so we do not end with a second pointer.
|
||||
if fieldValue.Elem().Kind() == reflect.Ptr {
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
||||
}
|
||||
|
||||
if isTextUnmarshaler(fieldValue) {
|
||||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
}
|
||||
|
||||
// If the underlying field is to a pointer type, we need to
|
||||
// assign the value as a pointer to it.
|
||||
// If it is not a pointer, we need to assign the value to the
|
||||
// field.
|
||||
dstField := field.ReflectValueOf(ctx, dst)
|
||||
if dstField.Kind() == reflect.Ptr {
|
||||
dstField.Set(fieldValue)
|
||||
} else {
|
||||
dstField.Set(fieldValue.Elem())
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
switch v := fieldValue.(type) {
|
||||
case encoding.TextMarshaler:
|
||||
// If the value is nil, we return nil, however, go nil values are not
|
||||
// always comparable, particularly when reflection is involved:
|
||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(b), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue