use gorm serialiser instead of custom hooks (#2156)

* add sqlite to debug/test image

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* test using gorm serialiser instead of custom hooks

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-02 11:41:58 +02:00 committed by GitHub
parent 3964dec1c6
commit bc9e83b52e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 240 additions and 351 deletions

View file

@ -2,11 +2,7 @@ package types
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
"time"
"tailscale.com/tailcfg"
@ -21,74 +17,6 @@ const (
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*i = IPPrefix(prefix)
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefix) Value() (driver.Value, error) {
prefixStr := netip.Prefix(i).String()
return prefixStr, nil
}
type IPPrefixes []netip.Prefix
func (i *IPPrefixes) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefixes) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}
type StringList []string
func (i *StringList) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i StringList) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}
type StateUpdateType int
func (su StateUpdateType) String() string {

View file

@ -1,8 +1,6 @@
package types
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
@ -15,7 +13,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -51,54 +48,16 @@ func (id NodeID) String() string {
type Node struct {
ID NodeID `gorm:"primary_key"`
// MachineKeyDatabaseField is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use MachineKey instead.
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
MachineKey key.MachinePublic `gorm:"-"`
MachineKey key.MachinePublic `gorm:"serializer:text"`
NodeKey key.NodePublic `gorm:"serializer:text"`
DiscoKey key.DiscoPublic `gorm:"serializer:text"`
// NodeKeyDatabaseField is the string representation of NodeKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use NodeKey instead.
NodeKeyDatabaseField string `gorm:"column:node_key"`
NodeKey key.NodePublic `gorm:"-"`
Endpoints []netip.AddrPort `gorm:"serializer:json"`
// DiscoKeyDatabaseField is the string representation of DiscoKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use DiscoKey instead.
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
DiscoKey key.DiscoPublic `gorm:"-"`
Hostinfo *tailcfg.Hostinfo `gorm:"serializer:json"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
Endpoints []netip.AddrPort `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
HostinfoDatabaseField string `gorm:"column:host_info"`
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
// IPv4DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V4 instead.
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
IPv4 *netip.Addr `gorm:"-"`
// IPv6DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V6 instead.
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
IPv6 *netip.Addr `gorm:"-"`
IPv4 *netip.Addr `gorm:"serializer:text"`
IPv6 *netip.Addr `gorm:"serializer:text"`
// Hostname represents the name given by the Tailscale
// client during registration
@ -116,7 +75,7 @@ type Node struct {
RegisterMethod string
ForcedTags StringList
ForcedTags []string `gorm:"serializer:json"`
// TODO(kradalby): This seems like irrelevant information?
AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
@ -216,16 +175,20 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
src := node.IPs()
allowedIPs := node2.IPs()
// TODO(kradalby): Regenerate this everytime the filter change, instead of
// every time we use it.
matchers := make([]matcher.Match, len(filter))
for i, rule := range filter {
matchers[i] = matcher.MatchFromFilterRule(rule)
}
for _, route := range node2.Routes {
if route.Enabled {
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr())
}
}
for _, rule := range filter {
// TODO(kradalby): Cache or pregen this
matcher := matcher.MatchFromFilterRule(rule)
for _, matcher := range matchers {
if !matcher.SrcsContainsIPs(src) {
continue
}
@ -255,109 +218,6 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
return found
}
// BeforeSave is a hook that ensures that some values that
// cannot be directly marshalled into database values are stored
// correctly in the database.
// This currently means storing the keys as strings.
func (node *Node) BeforeSave(tx *gorm.DB) error {
node.MachineKeyDatabaseField = node.MachineKey.String()
node.NodeKeyDatabaseField = node.NodeKey.String()
node.DiscoKeyDatabaseField = node.DiscoKey.String()
var endpoints StringList
for _, addrPort := range node.Endpoints {
endpoints = append(endpoints, addrPort.String())
}
node.EndpointsDatabaseField = endpoints
hi, err := json.Marshal(node.Hostinfo)
if err != nil {
return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
}
node.HostinfoDatabaseField = string(hi)
if node.IPv4 != nil {
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true
} else {
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false
}
if node.IPv6 != nil {
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true
} else {
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false
}
return nil
}
// AfterFind is a hook that ensures that Node objects fields that
// has a different type in the database is unwrapped and populated
// correctly.
// This currently unmarshals all the keys, stored as strings, into
// the proper types.
func (node *Node) AfterFind(tx *gorm.DB) error {
var machineKey key.MachinePublic
if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil {
return fmt.Errorf("unmarshalling machine key from db: %w", err)
}
node.MachineKey = machineKey
var nodeKey key.NodePublic
if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil {
return fmt.Errorf("unmarshalling node key from db: %w", err)
}
node.NodeKey = nodeKey
// DiscoKey might be empty if a node has not sent it to headscale.
// This means that this might fail if the disco key is empty.
if node.DiscoKeyDatabaseField != "" {
var discoKey key.DiscoPublic
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
return fmt.Errorf("unmarshalling disco key from db: %w", err)
}
node.DiscoKey = discoKey
}
endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField))
for idx, ep := range node.EndpointsDatabaseField {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return fmt.Errorf("parsing endpoint from db: %w", err)
}
endpoints[idx] = addrPort
}
node.Endpoints = endpoints
var hi tailcfg.Hostinfo
if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil {
return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
}
node.Hostinfo = &hi
if node.IPv4DatabaseField.Valid {
ip, err := netip.ParseAddr(node.IPv4DatabaseField.String)
if err != nil {
return fmt.Errorf("parsing IPv4 from database: %w", err)
}
node.IPv4 = &ip
}
if node.IPv6DatabaseField.Valid {
ip, err := netip.ParseAddr(node.IPv6DatabaseField.String)
if err != nil {
return fmt.Errorf("parsing IPv6 from database: %w", err)
}
node.IPv6 = &ip
}
return nil
}
func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{
Id: uint64(node.ID),

View file

@ -17,7 +17,7 @@ type Route struct {
Node Node
// TODO(kradalby): change this custom type to netip.Prefix
Prefix IPPrefix
Prefix netip.Prefix `gorm:"serializer:text"`
Advertised bool
Enabled bool
@ -31,7 +31,7 @@ func (r *Route) String() string {
}
func (r *Route) IsExitRoute() bool {
return tsaddr.IsExitRoute(netip.Prefix(r.Prefix))
return tsaddr.IsExitRoute(r.Prefix)
}
func (r *Route) IsAnnouncable() bool {
@ -59,8 +59,8 @@ func (rs Routes) Primaries() Routes {
return res
}
func (rs Routes) PrefixMap() map[IPPrefix][]Route {
res := map[IPPrefix][]Route{}
func (rs Routes) PrefixMap() map[netip.Prefix][]Route {
res := map[netip.Prefix][]Route{}
for _, route := range rs {
if _, ok := res[route.Prefix]; ok {
@ -80,7 +80,7 @@ func (rs Routes) Proto() []*v1.Route {
protoRoute := v1.Route{
Id: uint64(route.ID),
Node: route.Node.Proto(),
Prefix: netip.Prefix(route.Prefix).String(),
Prefix: route.Prefix.String(),
Advertised: route.Advertised,
Enabled: route.Enabled,
IsPrimary: route.IsPrimary,

View file

@ -10,16 +10,11 @@ import (
)
func TestPrefixMap(t *testing.T) {
ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool {
return x == y
})
ipp := func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
tests := []struct {
rs Routes
want map[IPPrefix][]Route
want map[netip.Prefix][]Route
}{
{
rs: Routes{
@ -27,7 +22,7 @@ func TestPrefixMap(t *testing.T) {
Prefix: ipp("10.0.0.0/24"),
},
},
want: map[IPPrefix][]Route{
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
@ -44,7 +39,7 @@ func TestPrefixMap(t *testing.T) {
Prefix: ipp("10.0.1.0/24"),
},
},
want: map[IPPrefix][]Route{
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
@ -68,7 +63,7 @@ func TestPrefixMap(t *testing.T) {
Enabled: false,
},
},
want: map[IPPrefix][]Route{
want: map[netip.Prefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
@ -86,7 +81,7 @@ func TestPrefixMap(t *testing.T) {
for idx, tt := range tests {
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
got := tt.rs.PrefixMap()
if diff := cmp.Diff(tt.want, got, prefixComparer, util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" {
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
}
})