create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -12,6 +12,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
|
@ -20,21 +21,16 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
errEmptyPolicy = Error("empty policy")
|
||||
errInvalidAction = Error("invalid action")
|
||||
errInvalidGroup = Error("invalid group")
|
||||
errInvalidTag = Error("invalid tag")
|
||||
errInvalidPortFormat = Error("invalid port format")
|
||||
errWildcardIsNeeded = Error("wildcard as port is required for the protocol")
|
||||
var (
|
||||
errEmptyPolicy = errors.New("empty policy")
|
||||
errInvalidAction = errors.New("invalid action")
|
||||
errInvalidGroup = errors.New("invalid group")
|
||||
errInvalidTag = errors.New("invalid tag")
|
||||
errInvalidPortFormat = errors.New("invalid port format")
|
||||
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
||||
)
|
||||
|
||||
const (
|
||||
Base8 = 8
|
||||
Base10 = 10
|
||||
BitSize16 = 16
|
||||
BitSize32 = 32
|
||||
BitSize64 = 64
|
||||
portRangeBegin = 0
|
||||
portRangeEnd = 65535
|
||||
expectedTokenItems = 2
|
||||
|
@ -123,7 +119,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error {
|
|||
}
|
||||
|
||||
func (h *Headscale) UpdateACLRules() error {
|
||||
machines, err := h.ListMachines()
|
||||
machines, err := h.db.ListMachines()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -230,7 +226,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
|||
return nil, errEmptyPolicy
|
||||
}
|
||||
|
||||
machines, err := h.ListMachines()
|
||||
machines, err := h.db.ListMachines()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -570,7 +566,7 @@ func excludeCorrectlyTaggedNodes(
|
|||
for tag := range aclPolicy.TagOwners {
|
||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
||||
ns := append(owners, user)
|
||||
if contains(ns, user) {
|
||||
if util.StringOrPrefixListContains(ns, user) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
|
@ -580,7 +576,7 @@ func excludeCorrectlyTaggedNodes(
|
|||
|
||||
found := false
|
||||
for _, t := range hi.RequestTags {
|
||||
if contains(tags, t) {
|
||||
if util.StringOrPrefixListContains(tags, t) {
|
||||
found = true
|
||||
|
||||
break
|
||||
|
@ -614,7 +610,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
|||
rang := strings.Split(portStr, "-")
|
||||
switch len(rang) {
|
||||
case 1:
|
||||
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -624,11 +620,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
|||
})
|
||||
|
||||
case expectedTokenItems:
|
||||
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
|
||||
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -754,7 +750,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
|||
|
||||
// check for forced tags
|
||||
for _, machine := range machines {
|
||||
if contains(machine.ForcedTags, alias) {
|
||||
if util.StringOrPrefixListContains(machine.ForcedTags, alias) {
|
||||
machine.IPAddresses.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
@ -783,7 +779,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
|||
machines := filterMachinesByUser(machines, user)
|
||||
for _, machine := range machines {
|
||||
hi := machine.GetHostInfo()
|
||||
if contains(hi.RequestTags, alias) {
|
||||
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
|
||||
machine.IPAddresses.AppendToIPSet(&build)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue