use IPSet in acls instead of string slice

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-04-28 16:11:02 +02:00 committed by Juan Font
parent 1a7ae11697
commit 735b185e7f
5 changed files with 209 additions and 104 deletions

186
acls.go
View file

@ -13,8 +13,8 @@ import (
"time"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/tailscale/hujson"
"go4.org/netipx"
"gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
@ -272,21 +272,41 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, expandedSrc := range expandedSrcs {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc,
Any: true,
})
} else if isGroup(rawSrc) {
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}
@ -295,10 +315,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
RuleExpires: nil,
Principals: principals,
SSHUsers: userMap,
Action: &action,
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
@ -329,7 +348,18 @@ func (pol *ACLPolicy) getIPsFromSource(
machines []Machine,
stripEmaildomain bool,
) ([]string, error) {
return pol.expandAlias(machines, src, stripEmaildomain)
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
// getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange
@ -397,11 +427,11 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
}
dests := []tailcfg.NetPortRange{}
for _, d := range expanded {
for _, p := range *ports {
for _, dest := range expanded.Prefixes() {
for _, port := range *ports {
pr := tailcfg.NetPortRange{
IP: d,
Ports: p,
IP: dest.String(),
Ports: port,
}
dests = append(dests, pr)
}
@ -472,28 +502,30 @@ func (pol *ACLPolicy) expandAlias(
machines Machines,
alias string,
stripEmailDomain bool,
) ([]string, error) {
if alias == "*" {
return []string{"*"}, nil
) (*netipx.IPSet, error) {
if isWildcard(alias) {
return parseIPSet("*", nil)
}
build := netipx.IPSetBuilder{}
log.Debug().
Str("alias", alias).
Msg("Expanding")
// if alias is a group
if strings.HasPrefix(alias, "group:") {
if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
}
// if alias is a tag
if strings.HasPrefix(alias, "tag:") {
if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
}
// if alias is a user
if ips := pol.getIPsForUser(alias, machines, stripEmailDomain); len(ips) > 0 {
return ips, nil
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
return ips, err
}
// if alias is an host
@ -516,7 +548,7 @@ func (pol *ACLPolicy) expandAlias(
log.Warn().Msgf("No IPs found with the alias %v", alias)
return []string{}, nil
return build.IPSet()
}
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
@ -561,7 +593,7 @@ func excludeCorrectlyTaggedNodes(
}
func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
if isWildcard(portsStr) {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
@ -636,7 +668,7 @@ func getTagOwners(
)
}
for _, owner := range ows {
if strings.HasPrefix(owner, "group:") {
if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
if err != nil {
return []string{}, err
@ -667,7 +699,7 @@ func (pol *ACLPolicy) getUsersInGroup(
)
}
for _, group := range aclGroups {
if strings.HasPrefix(group, "group:") {
if isGroup(group) {
return []string{}, fmt.Errorf(
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
errInvalidGroup,
@ -691,34 +723,34 @@ func (pol *ACLPolicy) getIPsFromGroup(
group string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group, stripEmailDomain)
if err != nil {
return ips, err
return &netipx.IPSet{}, err
}
for _, n := range users {
nodes := filterMachinesByUser(machines, n)
for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...)
for _, user := range users {
filteredMachines := filterMachinesByUser(machines, user)
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
}
return ips, nil
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromTag(
alias string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}
@ -726,17 +758,18 @@ func (pol *ACLPolicy) getIPsFromTag(
owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil {
if errors.Is(err, errInvalidTag) {
if len(ips) == 0 {
return ips, fmt.Errorf(
ipSet, _ := build.IPSet()
if len(ipSet.Prefixes()) == 0 {
return ipSet, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag,
alias,
)
}
return ips, nil
return build.IPSet()
} else {
return ips, err
return nil, err
}
}
@ -746,53 +779,62 @@ func (pol *ACLPolicy) getIPsFromTag(
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}
}
return ips, nil
return build.IPSet()
}
func (pol *ACLPolicy) getIPsForUser(
user string,
machines Machines,
stripEmailDomain bool,
) []string {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
nodes := filterMachinesByUser(machines, user)
nodes = excludeCorrectlyTaggedNodes(pol, nodes, user, stripEmailDomain)
filteredMachines := filterMachinesByUser(machines, user)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
// shortcurcuit if we have no machines to get ips from.
if len(filteredMachines) == 0 {
return nil, nil //nolint
}
return ips
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr,
machines Machines,
) ([]string, error) {
) (*netipx.IPSet, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
ips := []string{ip.String()}
matches := machines.FilterByIP(ip)
build := netipx.IPSetBuilder{}
build.Add(ip)
for _, machine := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
return lo.Uniq(ips), nil
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix,
machines Machines,
) ([]string, error) {
) (*netipx.IPSet, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
val := []string{prefix.String()}
build := netipx.IPSetBuilder{}
build.AddPrefix(prefix)
// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines {
@ -800,10 +842,22 @@ func (pol *ACLPolicy) getIPsFromIPPrefix(
// log.Trace().
// Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
if prefix.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}
}
return lo.Uniq(val), nil
return build.IPSet()
}
func isWildcard(str string) bool {
return str == "*"
}
func isGroup(str string) bool {
return strings.HasPrefix(str, "group:")
}
func isTag(str string) bool {
return strings.HasPrefix(str, "tag:")
}