Redo DNS configuration (#2034)

this commit changes and streamlines the dns_config into a new
key, dns. It removes a combination of outdates and incompatible
configuration options that made it easy to confuse what headscale
could and could not do, or what to expect from ones configuration.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-08-19 11:41:05 +02:00 committed by GitHub
parent 022fb24cd9
commit ac8491efec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1036 additions and 453 deletions

View file

@ -36,8 +36,7 @@ func tailNodes(
return tNodes, nil
}
// tailNode converts a Node into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS.
// tailNode converts a Node into a Tailscale Node.
func tailNode(
node *types.Node,
capVer tailcfg.CapabilityVersion,

View file

@ -55,12 +55,14 @@ func TestTailNode(t *testing.T) {
{
name: "empty-node",
node: &types.Node{
Hostinfo: &tailcfg.Hostinfo{},
GivenName: "empty",
Hostinfo: &tailcfg.Hostinfo{},
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: &tailcfg.Node{
Name: "empty",
StableID: "0",
Addresses: []netip.Prefix{},
AllowedIPs: []netip.Prefix{},

View file

@ -166,7 +166,7 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
}
const (
MinimumCapVersion tailcfg.CapabilityVersion = 58
MinimumCapVersion tailcfg.CapabilityVersion = 61
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol

View file

@ -20,6 +20,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/util/set"
)
const (
@ -88,6 +89,20 @@ type Config struct {
Tuning Tuning
}
type DNSConfig struct {
MagicDNS bool `mapstructure:"magic_dns"`
BaseDomain string `mapstructure:"base_domain"`
Nameservers Nameservers
SearchDomains []string `mapstructure:"search_domains"`
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
UserNameInMagicDNS bool `mapstructure:"use_username_in_magic_dns"`
}
type Nameservers struct {
Global []string
Split map[string][]string
}
type SqliteConfig struct {
Path string
WriteAheadLog bool
@ -201,7 +216,8 @@ func LoadConfig(path string, isFile bool) error {
}
}
viper.SetEnvPrefix("headscale")
envPrefix := "headscale"
viper.SetEnvPrefix(envPrefix)
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
@ -213,9 +229,13 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", TextLogFormat)
viper.SetDefault("dns_config", nil)
viper.SetDefault("dns_config.override_local_dns", true)
viper.SetDefault("dns_config.use_username_in_magic_dns", false)
viper.SetDefault("dns.magic_dns", true)
viper.SetDefault("dns.base_domain", "")
viper.SetDefault("dns.nameservers.global", []string{})
viper.SetDefault("dns.nameservers.split", map[string]string{})
viper.SetDefault("dns.search_domains", []string{})
viper.SetDefault("dns.extra_records", []tailcfg.DNSRecord{})
viper.SetDefault("dns.use_username_in_magic_dns", false)
viper.SetDefault("derp.server.enabled", false)
viper.SetDefault("derp.server.stun.enabled", true)
@ -259,17 +279,33 @@ func LoadConfig(path string, isFile bool) error {
}
if err := viper.ReadInConfig(); err != nil {
log.Warn().Err(err).Msg("Failed to read configuration from disk")
return fmt.Errorf("fatal error reading config file: %w", err)
}
depr := deprecator{
warns: make(set.Set[string]),
fatals: make(set.Set[string]),
}
// Register aliases for backward compatibility
// Has to be called _after_ viper.ReadInConfig()
// https://github.com/spf13/viper/issues/560
// Alias the old ACL Policy path with the new configuration option.
registerAliasAndDeprecate("policy.path", "acl_policy_path")
depr.warnWithAlias("policy.path", "acl_policy_path")
// Move dns_config -> dns
depr.warn("dns_config.override_local_dns")
depr.fatalIfNewKeyIsNotUsed("dns.magic_dns", "dns_config.magic_dns")
depr.fatalIfNewKeyIsNotUsed("dns.base_domain", "dns_config.base_domain")
depr.fatalIfNewKeyIsNotUsed("dns.nameservers.global", "dns_config.nameservers")
depr.fatalIfNewKeyIsNotUsed("dns.nameservers.split", "dns_config.restricted_nameservers")
depr.fatalIfNewKeyIsNotUsed("dns.search_domains", "dns_config.domains")
depr.fatalIfNewKeyIsNotUsed("dns.extra_records", "dns_config.extra_records")
depr.warn("dns_config.use_username_in_magic_dns")
depr.warn("dns.use_username_in_magic_dns")
depr.Log()
// Collect any validation errors and return them all at once
var errorText string
@ -485,123 +521,131 @@ func GetDatabaseConfig() DatabaseConfig {
}
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
func DNS() (DNSConfig, error) {
var dns DNSConfig
overrideLocalDNS := viper.GetBool("dns_config.override_local_dns")
// TODO: Use this instead of manually getting settings when
// UnmarshalKey is compatible with Environment Variables.
// err := viper.UnmarshalKey("dns", &dns)
// if err != nil {
// return DNSConfig{}, fmt.Errorf("unmarshaling dns config: %w", err)
// }
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
dns.MagicDNS = viper.GetBool("dns.magic_dns")
dns.BaseDomain = viper.GetString("dns.base_domain")
dns.Nameservers.Global = viper.GetStringSlice("dns.nameservers.global")
dns.Nameservers.Split = viper.GetStringMapStringSlice("dns.nameservers.split")
dns.SearchDomains = viper.GetStringSlice("dns.search_domains")
nameservers := []netip.Addr{}
resolvers := []*dnstype.Resolver{}
if viper.IsSet("dns.extra_records") {
var extraRecords []tailcfg.DNSRecord
for _, nameserverStr := range nameserversStr {
// Search for explicit DNS-over-HTTPS resolvers
if strings.HasPrefix(nameserverStr, "https://") {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nameserverStr,
})
// This nameserver can not be parsed as an IP address
continue
}
// Parse nameserver as a regular IP
nameserver, err := netip.ParseAddr(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers = append(nameservers, nameserver)
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nameserver.String(),
})
}
dnsConfig.Nameservers = nameservers
if overrideLocalDNS {
dnsConfig.Resolvers = resolvers
} else {
dnsConfig.FallbackResolvers = resolvers
}
err := viper.UnmarshalKey("dns.extra_records", &extraRecords)
if err != nil {
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
}
if viper.IsSet("dns_config.restricted_nameservers") {
dnsConfig.Routes = make(map[string][]*dnstype.Resolver)
domains := []string{}
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
restrictedResolvers := make(
[]*dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netip.ParseAddr(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
}
restrictedResolvers[index] = &dnstype.Resolver{
Addr: nameserver.String(),
}
}
dnsConfig.Routes[domain] = restrictedResolvers
domains = append(domains, domain)
}
dnsConfig.Domains = domains
}
if viper.IsSet("dns_config.extra_records") {
var extraRecords []tailcfg.DNSRecord
err := viper.UnmarshalKey("dns_config.extra_records", &extraRecords)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse dns_config.extra_records")
}
dnsConfig.ExtraRecords = extraRecords
}
if viper.IsSet("dns_config.magic_dns") {
dnsConfig.Proxied = viper.GetBool("dns_config.magic_dns")
}
var baseDomain string
if viper.IsSet("dns_config.base_domain") {
baseDomain = viper.GetString("dns_config.base_domain")
} else {
baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
}
if !viper.GetBool("dns_config.use_username_in_magic_dns") {
dnsConfig.Domains = []string{baseDomain}
} else {
log.Warn().Msg("DNS: Usernames in DNS has been deprecated, this option will be remove in future versions")
log.Warn().Msg("DNS: see 0.23.0 changelog for more information.")
}
if domains := viper.GetStringSlice("dns_config.domains"); len(domains) > 0 {
dnsConfig.Domains = append(dnsConfig.Domains, domains...)
}
log.Trace().Interface("dns_config", dnsConfig).Msg("DNS configuration loaded")
return dnsConfig, baseDomain
dns.ExtraRecords = extraRecords
}
return nil, ""
dns.UserNameInMagicDNS = viper.GetBool("dns.use_username_in_magic_dns")
return dns, nil
}
// GlobalResolvers returns the global DNS resolvers
// defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global {
warn := ""
if _, err := netip.ParseAddr(nsStr); err == nil {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nsStr,
})
continue
} else {
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
}
if _, err := url.Parse(nsStr); err == nil {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nsStr,
})
} else {
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
}
if warn != "" {
log.Warn().Msg(warn)
}
}
return resolvers
}
// SplitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver
for _, nsStr := range nameservers {
warn := ""
if _, err := netip.ParseAddr(nsStr); err == nil {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nsStr,
})
continue
} else {
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
}
if _, err := url.Parse(nsStr); err == nil {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nsStr,
})
} else {
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
}
if warn != "" {
log.Warn().Msg(warn)
}
}
routes[domain] = resolvers
}
return routes
}
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS {
log.Fatal().Msg("dns.base_domain must be set when using MagicDNS (dns.magic_dns)")
}
cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers()
routes := dns.SplitResolvers()
cfg.Routes = routes
if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain}
}
cfg.Domains = append(cfg.Domains, dns.SearchDomains...)
return &cfg
}
func PrefixV4() (*netip.Prefix, error) {
@ -693,7 +737,11 @@ func GetHeadscaleConfig() (*Config, error) {
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
}
dnsConfig, baseDomain := GetDNSConfig()
dnsConfig, err := DNS()
if err != nil {
return nil, err
}
derpConfig := GetDERPConfig()
logTailConfig := GetLogTailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
@ -711,8 +759,23 @@ func GetHeadscaleConfig() (*Config, error) {
oidcClientSecret = strings.TrimSpace(string(secretBytes))
}
serverURL := viper.GetString("server_url")
// BaseDomain cannot be the same as the server URL.
// This is because Tailscale takes over the domain in BaseDomain,
// causing the headscale server and DERP to be unreachable.
// For Tailscale upstream, the following is true:
// - DERP run on their own domains
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
//
// TODO(kradalby): remove dnsConfig.UserNameInMagicDNS check when removed.
if !dnsConfig.UserNameInMagicDNS && dnsConfig.BaseDomain != "" && strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New("server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.")
}
return &Config{
ServerURL: viper.GetString("server_url"),
ServerURL: serverURL,
Addr: viper.GetString("listen_addr"),
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
@ -726,7 +789,7 @@ func GetHeadscaleConfig() (*Config, error) {
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
viper.GetString("noise.private_key_path"),
),
BaseDomain: baseDomain,
BaseDomain: dnsConfig.BaseDomain,
DERP: derpConfig,
@ -738,8 +801,8 @@ func GetHeadscaleConfig() (*Config, error) {
TLS: GetTLSConfig(),
DNSConfig: dnsConfig,
DNSUserNameInMagicDNS: viper.GetBool("dns_config.use_username_in_magic_dns"),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
@ -805,19 +868,70 @@ func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
// registerAliasAndDeprecate will register an alias between the newKey and the oldKey,
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]
}
// warnWithAlias will register an alias between the newKey and the oldKey,
// and log a deprecation warning if the oldKey is set.
func registerAliasAndDeprecate(newKey, oldKey string) {
func (d *deprecator) warnWithAlias(newKey, oldKey string) {
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
viper.RegisterAlias(newKey, oldKey)
if viper.IsSet(oldKey) {
log.Warn().Msgf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey)
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey))
}
}
// deprecateAndFatal will log a fatal deprecation warning if the oldKey is set.
func deprecateAndFatal(newKey, oldKey string) {
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
func (d *deprecator) fatal(newKey, oldKey string) {
if viper.IsSet(oldKey) {
log.Fatal().Msgf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey)
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
}
}
// fatalIfNewKeyIsNotUsed deprecates and adds an entry to the fatal list of options if the oldKey is set and the new key is _not_ set.
// If the new key is set, a warning is emitted instead.
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
} else if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
}
}
// warn deprecates and adds an option to log a warning if the oldKey is set.
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
}
}
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
func (d *deprecator) warn(oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", oldKey))
}
}
func (d *deprecator) String() string {
var b strings.Builder
for _, w := range d.warns.Slice() {
fmt.Fprintf(&b, "WARN: %s\n", w)
}
for _, f := range d.fatals.Slice() {
fmt.Fprintf(&b, "FATAL: %s\n", f)
}
return b.String()
}
func (d *deprecator) Log() {
if len(d.fatals) > 0 {
log.Fatal().Msg("\n" + d.String())
} else if len(d.warns) > 0 {
log.Warn().Msg("\n" + d.String())
}
}

View file

@ -0,0 +1,272 @@
package types
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
func TestReadConfig(t *testing.T) {
tests := []struct {
name string
configPath string
setup func(*testing.T) (any, error)
want any
wantErr string
}{
{
name: "unmarshal-dns-full-config",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
if err != nil {
return nil, err
}
return dns, nil
},
want: DNSConfig{
MagicDNS: true,
BaseDomain: "example.com",
Nameservers: Nameservers{
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"},
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}},
},
ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
},
},
{
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: true,
Domains: []string{"example.com", "test.com", "bar.com"},
Resolvers: []*dnstype.Resolver{
{Addr: "1.1.1.1"},
{Addr: "1.0.0.1"},
{Addr: "2606:4700:4700::1111"},
{Addr: "2606:4700:4700::1001"},
{Addr: "https://dns.nextdns.io/abc123"},
},
Routes: map[string][]*dnstype.Resolver{
"darp.headscale.net": {{Addr: "1.1.1.1"}, {Addr: "8.8.8.8"}},
"foo.bar.com": {{Addr: "1.1.1.1"}},
},
ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
},
},
{
name: "unmarshal-dns-full-no-magic",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
if err != nil {
return nil, err
}
return dns, nil
},
want: DNSConfig{
MagicDNS: false,
BaseDomain: "example.com",
Nameservers: Nameservers{
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"},
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}},
},
ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
},
},
{
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: false,
Domains: []string{"example.com", "test.com", "bar.com"},
Resolvers: []*dnstype.Resolver{
{Addr: "1.1.1.1"},
{Addr: "1.0.0.1"},
{Addr: "2606:4700:4700::1111"},
{Addr: "2606:4700:4700::1001"},
{Addr: "https://dns.nextdns.io/abc123"},
},
Routes: map[string][]*dnstype.Resolver{
"darp.headscale.net": {{Addr: "1.1.1.1"}, {Addr: "8.8.8.8"}},
"foo.bar.com": {{Addr: "1.1.1.1"}},
},
ExtraRecords: []tailcfg.DNSRecord{
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
},
},
{
name: "base-domain-in-server-url-err",
configPath: "testdata/base-domain-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
return GetHeadscaleConfig()
},
want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
},
{
name: "base-domain-not-in-server-url",
configPath: "testdata/base-domain-not-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
if err != nil {
return nil, err
}
return map[string]string{
"server_url": cfg.ServerURL,
"base_domain": cfg.BaseDomain,
}, err
},
want: map[string]string{
"server_url": "https://derp.no",
"base_domain": "clients.derp.no",
},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Reset()
err := LoadConfig(tt.configPath, true)
assert.NoError(t, err)
conf, err := tt.setup(t)
if tt.wantErr != "" {
assert.Equal(t, tt.wantErr, err.Error())
return
}
assert.NoError(t, err)
if diff := cmp.Diff(tt.want, conf); diff != "" {
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestReadConfigFromEnv(t *testing.T) {
tests := []struct {
name string
configEnv map[string]string
setup func(*testing.T) (any, error)
want any
}{
{
name: "test-random-base-settings-with-env",
configEnv: map[string]string{
"HEADSCALE_LOG_LEVEL": "trace",
"HEADSCALE_DATABASE_SQLITE_WRITE_AHEAD_LOG": "false",
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
},
setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings())
assert.Equal(t, "trace", viper.GetString("log.level"))
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
return nil, nil
},
want: nil,
},
{
name: "unmarshal-dns-full-config",
configEnv: map[string]string{
"HEADSCALE_DNS_MAGIC_DNS": "true",
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
"HEADSCALE_DNS_USE_USERNAME_IN_MAGIC_DNS": "true",
// TODO(kradalby): Figure out how to pass these as env vars
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
// "HEADSCALE_DNS_EXTRA_RECORDS": `[{ name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }]`,
},
setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings())
dns, err := DNS()
if err != nil {
return nil, err
}
return dns, nil
},
want: DNSConfig{
MagicDNS: true,
BaseDomain: "example.com",
Nameservers: Nameservers{
Global: []string{"1.1.1.1", "8.8.8.8"},
Split: map[string][]string{
// "foo.bar.com": {"1.1.1.1"},
},
},
ExtraRecords: []tailcfg.DNSRecord{
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for k, v := range tt.configEnv {
t.Setenv(k, v)
}
viper.Reset()
err := LoadConfig("testdata/minimal.yaml", true)
assert.NoError(t, err)
conf, err := tt.setup(t)
assert.NoError(t, err)
if diff := cmp.Diff(tt.want, conf); diff != "" {
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
}
})
}
}

View file

@ -394,40 +394,39 @@ func (node *Node) Proto() *v1.Node {
}
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
var hostname string
if cfg.DNSConfig != nil && cfg.DNSConfig.Proxied { // MagicDNS
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
hostname := node.GivenName
if baseDomain != "" {
hostname = fmt.Sprintf(
"%s.%s",
node.GivenName,
baseDomain,
)
}
if cfg.DNSUserNameInMagicDNS {
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
node.User.Name,
baseDomain,
)
if cfg.DNSUserNameInMagicDNS {
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"failed to create valid FQDN (%s): %w",
hostname,
ErrHostnameTooLong,
)
}
} else {
hostname = node.GivenName
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
node.User.Name,
baseDomain,
)
}
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"failed to create valid FQDN (%s): %w",
hostname,
ErrHostnameTooLong,
)
}
return hostname, nil

View file

@ -195,7 +195,7 @@ func TestNodeFQDN(t *testing.T) {
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
want: "test",
want: "test.user.example.com",
},
{
name: "no-dnsconfig-with-username",
@ -206,7 +206,7 @@ func TestNodeFQDN(t *testing.T) {
},
},
domain: "example.com",
want: "test",
want: "test.example.com",
},
{
name: "all-set",
@ -271,7 +271,7 @@ func TestNodeFQDN(t *testing.T) {
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test",
want: "test.example.com",
},
{
name: "no-dnsconfig",
@ -282,7 +282,7 @@ func TestNodeFQDN(t *testing.T) {
},
},
domain: "example.com",
want: "test",
want: "test.example.com",
},
}

View file

@ -0,0 +1,16 @@
noise:
private_key_path: "private_key.pem"
prefixes:
v6: fd7a:115c:a1e0::/48
v4: 100.64.0.0/10
database:
type: sqlite3
server_url: "https://derp.no"
dns:
magic_dns: true
base_domain: derp.no
use_username_in_magic_dns: false

View file

@ -0,0 +1,16 @@
noise:
private_key_path: "private_key.pem"
prefixes:
v6: fd7a:115c:a1e0::/48
v4: 100.64.0.0/10
database:
type: sqlite3
server_url: "https://derp.no"
dns:
magic_dns: true
base_domain: clients.derp.no
use_username_in_magic_dns: false

37
hscontrol/types/testdata/dns_full.yaml vendored Normal file
View file

@ -0,0 +1,37 @@
# minimum to not fatal
noise:
private_key_path: "private_key.pem"
server_url: "https://derp.no"
dns:
magic_dns: true
base_domain: example.com
nameservers:
global:
- 1.1.1.1
- 1.0.0.1
- 2606:4700:4700::1111
- 2606:4700:4700::1001
- https://dns.nextdns.io/abc123
split:
foo.bar.com:
- 1.1.1.1
darp.headscale.net:
- 1.1.1.1
- 8.8.8.8
search_domains:
- test.com
- bar.com
extra_records:
- name: "grafana.myvpn.example.com"
type: "A"
value: "100.64.0.3"
# you can also put it in one line
- { name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }
use_username_in_magic_dns: true

View file

@ -0,0 +1,37 @@
# minimum to not fatal
noise:
private_key_path: "private_key.pem"
server_url: "https://derp.no"
dns:
magic_dns: false
base_domain: example.com
nameservers:
global:
- 1.1.1.1
- 1.0.0.1
- 2606:4700:4700::1111
- 2606:4700:4700::1001
- https://dns.nextdns.io/abc123
split:
foo.bar.com:
- 1.1.1.1
darp.headscale.net:
- 1.1.1.1
- 8.8.8.8
search_domains:
- test.com
- bar.com
extra_records:
- name: "grafana.myvpn.example.com"
type: "A"
value: "100.64.0.3"
# you can also put it in one line
- { name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }
use_username_in_magic_dns: true

3
hscontrol/types/testdata/minimal.yaml vendored Normal file
View file

@ -0,0 +1,3 @@
noise:
private_key_path: "private_key.pem"
server_url: "https://derp.no"