Only load needed part of configuration (#2109)

This commit is contained in:
Kristoffer Dalby 2024-09-07 09:23:58 +02:00 committed by GitHub
parent f368ed01ed
commit 8a3a0fee3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 196 additions and 324 deletions

View file

@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy(
case types.PolicyModeDB:
p, err := api.h.db.GetPolicy()
if err != nil {
return nil, err
return nil, fmt.Errorf("loading ACL from database: %w", err)
}
return &v1.GetPolicyResponse{
@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy(
absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
f, err := os.Open(absPath)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
}
defer f.Close()
b, err := io.ReadAll(f)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from file: %w", err)
}
return &v1.GetPolicyResponse{Policy: string(b)}, nil
}
return nil, nil
return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
}
func (api headscaleV1APIServer) SetPolicy(

View file

@ -212,6 +212,12 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
// It has to be called before LoadServerConfig and LoadCLIConfig.
// The configuration is not validated and the caller should check for errors
// using a validation function.
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
}
return nil
}
func validateServerConfig() error {
depr := deprecator{
warns: make(set.Set[string]),
fatals: make(set.Set[string]),
@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error {
if errorText != "" {
// nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
return nil
}
func GetTLSConfig() TLSConfig {
func tlsConfig() TLSConfig {
return TLSConfig{
LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"),
@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig {
}
}
func GetDERPConfig() DERPConfig {
func derpConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig {
}
}
func GetLogTailConfig() LogTailConfig {
func logtailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{
@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig {
}
}
func GetPolicyConfig() PolicyConfig {
func policyConfig() PolicyConfig {
policyPath := viper.GetString("policy.path")
policyMode := viper.GetString("policy.mode")
@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig {
}
}
func GetLogConfig() LogConfig {
func logConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
@ -473,9 +479,9 @@ func GetLogConfig() LogConfig {
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
case JSONLogFormat:
logFormat = JSONLogFormat
case "text":
case TextLogFormat:
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
@ -491,7 +497,7 @@ func GetLogConfig() LogConfig {
}
}
func GetDatabaseConfig() DatabaseConfig {
func databaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig {
}
}
func DNS() (DNSConfig, error) {
func dns() (DNSConfig, error) {
var dns DNSConfig
// TODO: Use this instead of manually getting settings when
@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) {
return dns, nil
}
// GlobalResolvers returns the global DNS resolvers
// globalResolvers returns the global DNS resolvers
// defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global {
@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
return resolvers
}
// SplitResolvers returns a map of domain to DNS resolvers.
// splitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver
@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
return routes
}
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS {
@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers()
cfg.Resolvers = dns.globalResolvers()
routes := dns.SplitResolvers()
routes := dns.splitResolvers()
cfg.Routes = routes
if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain}
@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
return &cfg
}
func PrefixV4() (*netip.Prefix, error) {
func prefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
if prefixV4Str == "" {
@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) {
return &prefixV4, nil
}
func PrefixV6() (*netip.Prefix, error) {
func prefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" {
@ -723,27 +729,37 @@ func PrefixV6() (*netip.Prefix, error) {
return &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
return &Config{
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
// LoadCLIConfig returns the needed configuration for the CLI client
// of Headscale to connect to a Headscale server.
func LoadCLIConfig() (*Config, error) {
return &Config{
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
UnixSocket: viper.GetString("unix_socket"),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
}
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
if err := validateServerConfig(); err != nil {
return nil, err
}
logConfig := GetLogConfig()
logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := PrefixV4()
prefix4, err := prefixV4()
if err != nil {
return nil, err
}
prefix6, err := PrefixV6()
prefix6, err := prefixV6()
if err != nil {
return nil, err
}
@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) {
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
}
dnsConfig, err := DNS()
dnsConfig, err := dns()
if err != nil {
return nil, err
}
derpConfig := GetDERPConfig()
logTailConfig := GetLogTailConfig()
derpConfig := derpConfig()
logTailConfig := logtailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
oidcClientSecret := viper.GetString("oidc.client_secret")
@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) {
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
DisableUpdateCheck: false,
PrefixV4: prefix4,
PrefixV6: prefix6,
@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout",
),
Database: GetDatabaseConfig(),
Database: databaseConfig(),
TLS: GetTLSConfig(),
TLS: tlsConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSConfig: dnsToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
ACMEEmail: viper.GetString("acme_email"),
@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) {
LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort,
Policy: GetPolicyConfig(),
Policy: policyConfig(),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]

View file

@ -1,6 +1,8 @@
package types
import (
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-config",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -48,12 +50,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: true,
@ -79,7 +81,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-no-magic",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -105,12 +107,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: false,
@ -136,7 +138,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-in-server-url-err",
configPath: "testdata/base-domain-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
return GetHeadscaleConfig()
return LoadServerConfig()
},
want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
@ -145,7 +147,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-not-in-server-url",
configPath: "testdata/base-domain-not-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -165,7 +167,7 @@ func TestReadConfig(t *testing.T) {
name: "policy-path-is-loaded",
configPath: "testdata/policy-path-is-loaded.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -245,7 +247,7 @@ func TestReadConfigFromEnv(t *testing.T) {
setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings())
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -289,3 +291,49 @@ func TestReadConfigFromEnv(t *testing.T) {
})
}
}
func TestTLSConfigValidation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
t.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
// Check configuration validation errors (1)
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
err = validateServerConfig()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
}