Use specific types for all fields on machine (no datatypes.json)

This commit removes the need for datatypes.JSON and makes the code a bit
cleaner by allowing us to use proper types throughout the code when it
comes to hostinfo and other datatypes on the machine object.

This allows us to remove alot of unmarshal/marshal operations and remove
a lot of obsolete error checks.

This following commits will clean away a lot of untyped data and
uneccessary error checks.
This commit is contained in:
Kristoffer Dalby 2022-03-01 16:31:25 +00:00
parent 94c5474212
commit 8a95fe517a
4 changed files with 101 additions and 173 deletions

View file

@ -2,7 +2,6 @@ package headscale
import (
"database/sql/driver"
"encoding/json"
"fmt"
"sort"
"strconv"
@ -13,7 +12,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/datatypes"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -51,9 +49,9 @@ type Machine struct {
LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo datatypes.JSON
Endpoints datatypes.JSON
EnabledRoutes datatypes.JSON
HostInfo HostInfo
Endpoints StringList
EnabledRoutes IPPrefixes
CreatedAt time.Time
UpdatedAt time.Time
@ -393,20 +391,8 @@ func (h *Headscale) HardDeleteMachine(machine *Machine) error {
}
// GetHostInfo returns a Hostinfo struct for the machine.
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{}
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(hi, &hostinfo)
if err != nil {
return nil, err
}
}
return &hostinfo, nil
func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
return tailcfg.Hostinfo(machine.HostInfo)
}
func (h *Headscale) isOutdated(machine *Machine) bool {
@ -536,54 +522,12 @@ func (machine Machine) toNode(
// TODO(kradalby): Needs investigation, We probably dont need this condition
// now that we dont have shared nodes
if includeRoutes {
routesStr := []string{}
if len(machine.EnabledRoutes) != 0 {
allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(allwIps, &routesStr)
if err != nil {
return nil, err
}
}
for _, routeStr := range routesStr {
ip, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
allowedIPs = append(allowedIPs, ip)
}
}
endpoints := []string{}
if len(machine.Endpoints) != 0 {
be, err := machine.Endpoints.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(be, &endpoints)
if err != nil {
return nil, err
}
}
hostinfo := tailcfg.Hostinfo{}
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(hi, &hostinfo)
if err != nil {
return nil, err
}
allowedIPs = append(allowedIPs, machine.EnabledRoutes...)
}
var derp string
if hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", hostinfo.NetInfo.PreferredDERP)
if machine.HostInfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP)
} else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
}
@ -614,6 +558,8 @@ func (machine Machine) toNode(
hostname = machine.Name
}
hostInfo := machine.GetHostInfo()
node := tailcfg.Node{
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
@ -627,10 +573,10 @@ func (machine Machine) toNode(
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
Endpoints: endpoints,
Endpoints: machine.Endpoints,
DERP: derp,
Hostinfo: hostinfo.View(),
Hostinfo: hostInfo.View(),
Created: machine.CreatedAt,
LastSeen: machine.LastSeen,
@ -786,37 +732,12 @@ func (h *Headscale) RegisterMachine(
return machine, nil
}
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
return hostInfo.RoutableIPs, nil
func (machine *Machine) GetAdvertisedRoutes() []netaddr.IPPrefix {
return machine.HostInfo.RoutableIPs
}
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
routes[index] = route
}
return routes, nil
func (machine *Machine) GetEnabledRoutes() []netaddr.IPPrefix {
return machine.EnabledRoutes
}
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
@ -825,10 +746,7 @@ func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
return false
}
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return false
}
enabledRoutes := machine.GetEnabledRoutes()
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
@ -852,13 +770,8 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return err
}
for _, newRoute := range newRoutes {
if !containsIPPrefix(availableRoutes, newRoute) {
if !containsIPPrefix(machine.GetAdvertisedRoutes(), newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Name,
@ -867,30 +780,19 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
}
}
routes, err := json.Marshal(newRoutes)
if err != nil {
return err
}
machine.EnabledRoutes = datatypes.JSON(routes)
machine.EnabledRoutes = newRoutes
h.db.Save(&machine)
return nil
}
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
func (machine *Machine) RoutesToProto() *v1.Routes {
availableRoutes := machine.GetAdvertisedRoutes()
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return nil, err
}
enabledRoutes := machine.GetEnabledRoutes()
return &v1.Routes{
AdvertisedRoutes: ipPrefixToString(availableRoutes),
EnabledRoutes: ipPrefixToString(enabledRoutes),
}, nil
}
}