Migrate IP fields in database to dedicated columns (#1869)
This commit is contained in:
parent
85cef84e17
commit
2ce23df45a
39 changed files with 1885 additions and 1055 deletions
|
@ -1,12 +1,11 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -14,7 +13,6 @@ import (
|
|||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
|
@ -83,7 +81,19 @@ type Node struct {
|
|||
HostinfoDatabaseField string `gorm:"column:host_info"`
|
||||
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
|
||||
|
||||
IPAddresses NodeAddresses
|
||||
// IPv4DatabaseField is the string representation of v4 address,
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use V4 instead.
|
||||
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
|
||||
IPv4 *netip.Addr `gorm:"-"`
|
||||
|
||||
// IPv6DatabaseField is the string representation of v4 address,
|
||||
// it is _only_ used for reading and writing the key to the
|
||||
// database and should not be used.
|
||||
// Use V6 instead.
|
||||
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
|
||||
IPv6 *netip.Addr `gorm:"-"`
|
||||
|
||||
// Hostname represents the name given by the Tailscale
|
||||
// client during registration
|
||||
|
@ -123,89 +133,6 @@ type (
|
|||
Nodes []*Node
|
||||
)
|
||||
|
||||
type NodeAddresses []netip.Addr
|
||||
|
||||
func (na NodeAddresses) Sort() {
|
||||
sort.Slice(na, func(index1, index2 int) bool {
|
||||
if na[index1].Is4() && na[index2].Is6() {
|
||||
return true
|
||||
}
|
||||
if na[index1].Is6() && na[index2].Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
return na[index1].Compare(na[index2]) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func (na NodeAddresses) StringSlice() []string {
|
||||
na.Sort()
|
||||
strSlice := make([]string, 0, len(na))
|
||||
for _, addr := range na {
|
||||
strSlice = append(strSlice, addr.String())
|
||||
}
|
||||
|
||||
return strSlice
|
||||
}
|
||||
|
||||
func (na NodeAddresses) Prefixes() []netip.Prefix {
|
||||
addrs := []netip.Prefix{}
|
||||
for _, nodeAddress := range na {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (na NodeAddresses) InIPSet(set *netipx.IPSet) bool {
|
||||
for _, nodeAddr := range na {
|
||||
if set.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AppendToIPSet adds the individual ips in NodeAddresses to a
|
||||
// given netipx.IPSetBuilder.
|
||||
func (na NodeAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
for _, ip := range na {
|
||||
build.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (na *NodeAddresses) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case string:
|
||||
addresses := strings.Split(value, ",")
|
||||
*na = (*na)[:0]
|
||||
for _, addr := range addresses {
|
||||
if len(addr) < 1 {
|
||||
continue
|
||||
}
|
||||
parsed, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*na = append(*na, parsed)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (na NodeAddresses) Value() (driver.Value, error) {
|
||||
addresses := strings.Join(na.StringSlice(), ",")
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// IsExpired returns whether the node registration has expired.
|
||||
func (node Node) IsExpired() bool {
|
||||
// If Expiry is not set, the client has not indicated that
|
||||
|
@ -224,8 +151,65 @@ func (node *Node) IsEphemeral() bool {
|
|||
return node.AuthKey != nil && node.AuthKey.Ephemeral
|
||||
}
|
||||
|
||||
func (node *Node) IPs() []netip.Addr {
|
||||
var ret []netip.Addr
|
||||
|
||||
if node.IPv4 != nil {
|
||||
ret = append(ret, *node.IPv4)
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
ret = append(ret, *node.IPv6)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (node *Node) Prefixes() []netip.Prefix {
|
||||
addrs := []netip.Prefix{}
|
||||
for _, nodeAddress := range node.IPs() {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (node *Node) IPsAsString() []string {
|
||||
var ret []string
|
||||
|
||||
if node.IPv4 != nil {
|
||||
ret = append(ret, node.IPv4.String())
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
ret = append(ret, node.IPv6.String())
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (node *Node) InIPSet(set *netipx.IPSet) bool {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if set.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AppendToIPSet adds the individual ips in NodeAddresses to a
|
||||
// given netipx.IPSetBuilder.
|
||||
func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
for _, ip := range node.IPs() {
|
||||
build.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
||||
allowedIPs := append([]netip.Addr{}, node2.IPAddresses...)
|
||||
src := node.IPs()
|
||||
allowedIPs := node2.IPs()
|
||||
|
||||
for _, route := range node2.Routes {
|
||||
if route.Enabled {
|
||||
|
@ -237,7 +221,7 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
|||
// TODO(kradalby): Cache or pregen this
|
||||
matcher := matcher.MatchFromFilterRule(rule)
|
||||
|
||||
if !matcher.SrcsContainsIPs([]netip.Addr(node.IPAddresses)) {
|
||||
if !matcher.SrcsContainsIPs(src) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -250,13 +234,16 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
|||
}
|
||||
|
||||
func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
||||
found := make(Nodes, 0)
|
||||
var found Nodes
|
||||
|
||||
for _, node := range nodes {
|
||||
for _, mIP := range node.IPAddresses {
|
||||
if ip == mIP {
|
||||
found = append(found, node)
|
||||
}
|
||||
if node.IPv4 != nil && ip == *node.IPv4 {
|
||||
found = append(found, node)
|
||||
continue
|
||||
}
|
||||
|
||||
if node.IPv6 != nil && ip == *node.IPv6 {
|
||||
found = append(found, node)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,10 +268,22 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
|
|||
|
||||
hi, err := json.Marshal(node.Hostinfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
|
||||
return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
|
||||
}
|
||||
node.HostinfoDatabaseField = string(hi)
|
||||
|
||||
if node.IPv4 != nil {
|
||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true
|
||||
} else {
|
||||
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false
|
||||
}
|
||||
|
||||
if node.IPv6 != nil {
|
||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true
|
||||
} else {
|
||||
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -296,19 +295,19 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
|
|||
func (node *Node) AfterFind(tx *gorm.DB) error {
|
||||
var machineKey key.MachinePublic
|
||||
if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling machine key from db: %w", err)
|
||||
}
|
||||
node.MachineKey = machineKey
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling node key from db: %w", err)
|
||||
}
|
||||
node.NodeKey = nodeKey
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling disco key from db: %w", err)
|
||||
}
|
||||
node.DiscoKey = discoKey
|
||||
|
||||
|
@ -316,7 +315,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
|||
for idx, ep := range node.EndpointsDatabaseField {
|
||||
addrPort, err := netip.ParseAddrPort(ep)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint from db: %w", err)
|
||||
return fmt.Errorf("parsing endpoint from db: %w", err)
|
||||
}
|
||||
|
||||
endpoints[idx] = addrPort
|
||||
|
@ -325,12 +324,28 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
|||
|
||||
var hi tailcfg.Hostinfo
|
||||
if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil {
|
||||
log.Trace().Err(err).Msgf("Hostinfo content: %s", node.HostinfoDatabaseField)
|
||||
|
||||
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
|
||||
return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
|
||||
}
|
||||
node.Hostinfo = &hi
|
||||
|
||||
if node.IPv4DatabaseField.Valid {
|
||||
ip, err := netip.ParseAddr(node.IPv4DatabaseField.String)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing IPv4 from database: %w", err)
|
||||
}
|
||||
|
||||
node.IPv4 = &ip
|
||||
}
|
||||
|
||||
if node.IPv6DatabaseField.Valid {
|
||||
ip, err := netip.ParseAddr(node.IPv6DatabaseField.String)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing IPv6 from database: %w", err)
|
||||
}
|
||||
|
||||
node.IPv6 = &ip
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -339,9 +354,11 @@ func (node *Node) Proto() *v1.Node {
|
|||
Id: uint64(node.ID),
|
||||
MachineKey: node.MachineKey.String(),
|
||||
|
||||
NodeKey: node.NodeKey.String(),
|
||||
DiscoKey: node.DiscoKey.String(),
|
||||
IpAddresses: node.IPAddresses.StringSlice(),
|
||||
NodeKey: node.NodeKey.String(),
|
||||
DiscoKey: node.DiscoKey.String(),
|
||||
|
||||
// TODO(kradalby): replace list with v4, v6 field?
|
||||
IpAddresses: node.IPsAsString(),
|
||||
Name: node.Hostname,
|
||||
GivenName: node.GivenName,
|
||||
User: node.User.Proto(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue