remove "stripEmailDomain" argument

This commit makes a wrapper function round the normalisation requiring
"stripEmailDomain" which has to be passed in almost all functions of
headscale by loading it from Viper instead.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-12 15:29:34 +02:00 committed by Kristoffer Dalby
parent 161243c787
commit 717abe89c1
16 changed files with 127 additions and 220 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"go4.org/netipx"
"gopkg.in/check.v1"
@ -199,7 +200,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
@ -230,7 +231,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
pol, err := LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}
@ -246,7 +247,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
},
},
}
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
}
@ -265,7 +266,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
},
},
}
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
}
@ -281,7 +282,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
},
}
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
}
@ -310,7 +311,7 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -366,7 +367,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -401,7 +402,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -428,7 +429,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -459,7 +460,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -483,8 +484,8 @@ func Test_expandGroup(t *testing.T) {
pol ACLPolicy
}
type args struct {
group string
stripEmailDomain bool
group string
stripEmail bool
}
tests := []struct {
name string
@ -504,8 +505,7 @@ func Test_expandGroup(t *testing.T) {
},
},
args: args{
group: "group:test",
stripEmailDomain: true,
group: "group:test",
},
want: []string{"user1", "user2", "user3"},
wantErr: false,
@ -521,14 +521,13 @@ func Test_expandGroup(t *testing.T) {
},
},
args: args{
group: "group:undefined",
stripEmailDomain: true,
group: "group:undefined",
},
want: []string{},
wantErr: true,
},
{
name: "Expand emails in group",
name: "Expand emails in group strip domains",
field: field{
pol: ACLPolicy{
Groups: Groups{
@ -540,8 +539,8 @@ func Test_expandGroup(t *testing.T) {
},
},
args: args{
group: "group:admin",
stripEmailDomain: true,
group: "group:admin",
stripEmail: true,
},
want: []string{"joe.bar", "john.doe"},
wantErr: false,
@ -559,8 +558,7 @@ func Test_expandGroup(t *testing.T) {
},
},
args: args{
group: "group:admin",
stripEmailDomain: false,
group: "group:admin",
},
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
wantErr: false,
@ -568,17 +566,20 @@ func Test_expandGroup(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
viper.Set("oidc.strip_email_domain", test.args.stripEmail)
got, err := test.field.pol.getUsersInGroup(
test.args.group,
test.args.stripEmailDomain,
)
if (err != nil) != test.wantErr {
t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr)
return
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("expandGroup() = %v, want %v", got, test.want)
if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff)
}
})
}
@ -586,9 +587,8 @@ func Test_expandGroup(t *testing.T) {
func Test_expandTagOwners(t *testing.T) {
type args struct {
aclPolicy *ACLPolicy
tag string
stripEmailDomain bool
aclPolicy *ACLPolicy
tag string
}
tests := []struct {
name string
@ -602,8 +602,7 @@ func Test_expandTagOwners(t *testing.T) {
aclPolicy: &ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"user1"}},
},
tag: "tag:test",
stripEmailDomain: true,
tag: "tag:test",
},
want: []string{"user1"},
wantErr: false,
@ -615,8 +614,7 @@ func Test_expandTagOwners(t *testing.T) {
Groups: Groups{"group:foo": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
},
tag: "tag:test",
stripEmailDomain: true,
tag: "tag:test",
},
want: []string{"user1", "user2"},
wantErr: false,
@ -628,8 +626,7 @@ func Test_expandTagOwners(t *testing.T) {
Groups: Groups{"group:foo": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}},
},
tag: "tag:test",
stripEmailDomain: true,
tag: "tag:test",
},
want: []string{"user1", "user2", "user3"},
wantErr: false,
@ -640,8 +637,7 @@ func Test_expandTagOwners(t *testing.T) {
aclPolicy: &ACLPolicy{
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}},
},
tag: "tag:test",
stripEmailDomain: true,
tag: "tag:test",
},
want: []string{},
wantErr: true,
@ -653,8 +649,7 @@ func Test_expandTagOwners(t *testing.T) {
Groups: Groups{"group:bar": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}},
},
tag: "tag:test",
stripEmailDomain: true,
tag: "tag:test",
},
want: []string{},
wantErr: true,
@ -665,7 +660,6 @@ func Test_expandTagOwners(t *testing.T) {
got, err := getTagOwners(
test.args.aclPolicy,
test.args.tag,
test.args.stripEmailDomain,
)
if (err != nil) != test.wantErr {
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr)
@ -861,10 +855,9 @@ func Test_expandAlias(t *testing.T) {
pol ACLPolicy
}
type args struct {
machines types.Machines
aclPolicy ACLPolicy
alias string
stripEmailDomain bool
machines types.Machines
aclPolicy ACLPolicy
alias string
}
tests := []struct {
name string
@ -888,7 +881,6 @@ func Test_expandAlias(t *testing.T) {
},
},
},
stripEmailDomain: true,
},
want: set([]string{}, []string{
"0.0.0.0/0",
@ -931,7 +923,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{
"100.64.0.1", "100.64.0.2", "100.64.0.3",
@ -973,7 +964,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{}, []string{}),
wantErr: true,
@ -984,9 +974,8 @@ func Test_expandAlias(t *testing.T) {
pol: ACLPolicy{},
},
args: args{
alias: "10.0.0.3",
machines: types.Machines{},
stripEmailDomain: true,
alias: "10.0.0.3",
machines: types.Machines{},
},
want: set([]string{
"10.0.0.3",
@ -999,9 +988,8 @@ func Test_expandAlias(t *testing.T) {
pol: ACLPolicy{},
},
args: args{
alias: "10.0.0.1",
machines: types.Machines{},
stripEmailDomain: true,
alias: "10.0.0.1",
machines: types.Machines{},
},
want: set([]string{
"10.0.0.1",
@ -1023,7 +1011,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{
"10.0.0.1",
@ -1046,7 +1033,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{
"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
@ -1069,7 +1055,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{
"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1",
@ -1086,9 +1071,8 @@ func Test_expandAlias(t *testing.T) {
},
},
args: args{
alias: "testy",
machines: types.Machines{},
stripEmailDomain: true,
alias: "testy",
machines: types.Machines{},
},
want: set([]string{}, []string{"10.0.0.132/32"}),
wantErr: false,
@ -1103,9 +1087,8 @@ func Test_expandAlias(t *testing.T) {
},
},
args: args{
alias: "homeNetwork",
machines: types.Machines{},
stripEmailDomain: true,
alias: "homeNetwork",
machines: types.Machines{},
},
want: set([]string{}, []string{"192.168.1.0/24"}),
wantErr: false,
@ -1116,10 +1099,9 @@ func Test_expandAlias(t *testing.T) {
pol: ACLPolicy{},
},
args: args{
alias: "10.0.0.0/16",
machines: types.Machines{},
aclPolicy: ACLPolicy{},
stripEmailDomain: true,
alias: "10.0.0.0/16",
machines: types.Machines{},
aclPolicy: ACLPolicy{},
},
want: set([]string{}, []string{"10.0.0.0/16"}),
wantErr: false,
@ -1169,7 +1151,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "joe"},
},
},
stripEmailDomain: true,
},
want: set([]string{
"100.64.0.1", "100.64.0.2",
@ -1214,7 +1195,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{}, []string{}),
wantErr: true,
@ -1254,7 +1234,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false,
@ -1302,7 +1281,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false,
@ -1352,7 +1330,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "joe"},
},
},
stripEmailDomain: true,
},
want: set([]string{"100.64.0.4"}, []string{}),
wantErr: false,
@ -1363,7 +1340,6 @@ func Test_expandAlias(t *testing.T) {
got, err := test.field.pol.ExpandAlias(
test.args.machines,
test.args.alias,
test.args.stripEmailDomain,
)
if (err != nil) != test.wantErr {
t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr)
@ -1379,10 +1355,9 @@ func Test_expandAlias(t *testing.T) {
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
type args struct {
aclPolicy *ACLPolicy
nodes types.Machines
user string
stripEmailDomain bool
aclPolicy *ACLPolicy
nodes types.Machines
user string
}
tests := []struct {
name string
@ -1426,8 +1401,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
User: types.User{Name: "joe"},
},
},
user: "joe",
stripEmailDomain: true,
user: "joe",
},
want: types.Machines{
{
@ -1477,8 +1451,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
User: types.User{Name: "joe"},
},
},
user: "joe",
stripEmailDomain: true,
user: "joe",
},
want: types.Machines{
{
@ -1519,8 +1492,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
User: types.User{Name: "joe"},
},
},
user: "joe",
stripEmailDomain: true,
user: "joe",
},
want: types.Machines{
{
@ -1565,8 +1537,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
User: types.User{Name: "joe"},
},
},
user: "joe",
stripEmailDomain: true,
user: "joe",
},
want: types.Machines{
{
@ -1606,7 +1577,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.aclPolicy,
test.args.nodes,
test.args.user,
test.args.stripEmailDomain,
)
if !reflect.DeepEqual(got, test.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)
@ -1620,9 +1590,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
pol ACLPolicy
}
type args struct {
machine types.Machine
peers types.Machines
stripEmailDomain bool
machine types.Machine
peers types.Machines
}
tests := []struct {
name string
@ -1652,9 +1621,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
},
},
args: args{
machine: types.Machine{},
peers: types.Machines{},
stripEmailDomain: true,
machine: types.Machine{},
peers: types.Machines{},
},
want: []tailcfg.FilterRule{
{
@ -1709,7 +1677,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
User: types.User{Name: "mickael"},
},
},
stripEmailDomain: true,
},
want: []tailcfg.FilterRule{
{
@ -1743,7 +1710,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
got, err := tt.field.pol.generateFilterRules(
&tt.args.machine,
tt.args.peers,
tt.args.stripEmailDomain,
)
if (err != nil) != tt.wantErr {
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
@ -1761,9 +1727,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
func Test_getTags(t *testing.T) {
type args struct {
aclPolicy *ACLPolicy
machine types.Machine
stripEmailDomain bool
aclPolicy *ACLPolicy
machine types.Machine
}
tests := []struct {
name string
@ -1787,7 +1752,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:valid"},
},
},
stripEmailDomain: false,
},
wantValid: []string{"tag:valid"},
wantInvalid: nil,
@ -1808,7 +1772,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:valid", "tag:invalid"},
},
},
stripEmailDomain: false,
},
wantValid: []string{"tag:valid"},
wantInvalid: []string{"tag:invalid"},
@ -1833,7 +1796,6 @@ func Test_getTags(t *testing.T) {
},
},
},
stripEmailDomain: false,
},
wantValid: []string{"tag:valid"},
wantInvalid: []string{"tag:invalid"},
@ -1854,7 +1816,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:invalid", "very-invalid"},
},
},
stripEmailDomain: false,
},
wantValid: nil,
wantInvalid: []string{"tag:invalid", "very-invalid"},
@ -1871,7 +1832,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:invalid", "very-invalid"},
},
},
stripEmailDomain: false,
},
wantValid: nil,
wantInvalid: []string{"tag:invalid", "very-invalid"},
@ -1881,7 +1841,6 @@ func Test_getTags(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
test.args.machine,
test.args.stripEmailDomain,
)
for _, valid := range gotValid {
if !util.StringOrPrefixListContains(test.wantValid, valid) {
@ -2589,7 +2548,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers, false)
got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers)
assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" {