Only load needed part of configuration (#2109)

This commit is contained in:
Kristoffer Dalby 2024-09-07 09:23:58 +02:00 committed by GitHub
parent f368ed01ed
commit 8a3a0fee3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 196 additions and 324 deletions

View file

@ -212,6 +212,12 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
// It has to be called before LoadServerConfig and LoadCLIConfig.
// The configuration is not validated and the caller should check for errors
// using a validation function.
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
}
return nil
}
func validateServerConfig() error {
depr := deprecator{
warns: make(set.Set[string]),
fatals: make(set.Set[string]),
@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error {
if errorText != "" {
// nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
return nil
}
func GetTLSConfig() TLSConfig {
func tlsConfig() TLSConfig {
return TLSConfig{
LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"),
@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig {
}
}
func GetDERPConfig() DERPConfig {
func derpConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig {
}
}
func GetLogTailConfig() LogTailConfig {
func logtailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{
@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig {
}
}
func GetPolicyConfig() PolicyConfig {
func policyConfig() PolicyConfig {
policyPath := viper.GetString("policy.path")
policyMode := viper.GetString("policy.mode")
@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig {
}
}
func GetLogConfig() LogConfig {
func logConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
@ -473,9 +479,9 @@ func GetLogConfig() LogConfig {
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
case JSONLogFormat:
logFormat = JSONLogFormat
case "text":
case TextLogFormat:
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
@ -491,7 +497,7 @@ func GetLogConfig() LogConfig {
}
}
func GetDatabaseConfig() DatabaseConfig {
func databaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig {
}
}
func DNS() (DNSConfig, error) {
func dns() (DNSConfig, error) {
var dns DNSConfig
// TODO: Use this instead of manually getting settings when
@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) {
return dns, nil
}
// GlobalResolvers returns the global DNS resolvers
// globalResolvers returns the global DNS resolvers
// defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global {
@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
return resolvers
}
// SplitResolvers returns a map of domain to DNS resolvers.
// splitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver
@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
return routes
}
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS {
@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers()
cfg.Resolvers = dns.globalResolvers()
routes := dns.SplitResolvers()
routes := dns.splitResolvers()
cfg.Routes = routes
if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain}
@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
return &cfg
}
func PrefixV4() (*netip.Prefix, error) {
func prefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
if prefixV4Str == "" {
@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) {
return &prefixV4, nil
}
func PrefixV6() (*netip.Prefix, error) {
func prefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" {
@ -723,27 +729,37 @@ func PrefixV6() (*netip.Prefix, error) {
return &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
return &Config{
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
// LoadCLIConfig returns the needed configuration for the CLI client
// of Headscale to connect to a Headscale server.
func LoadCLIConfig() (*Config, error) {
return &Config{
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
UnixSocket: viper.GetString("unix_socket"),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
}
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
if err := validateServerConfig(); err != nil {
return nil, err
}
logConfig := GetLogConfig()
logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := PrefixV4()
prefix4, err := prefixV4()
if err != nil {
return nil, err
}
prefix6, err := PrefixV6()
prefix6, err := prefixV6()
if err != nil {
return nil, err
}
@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) {
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
}
dnsConfig, err := DNS()
dnsConfig, err := dns()
if err != nil {
return nil, err
}
derpConfig := GetDERPConfig()
logTailConfig := GetLogTailConfig()
derpConfig := derpConfig()
logTailConfig := logtailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
oidcClientSecret := viper.GetString("oidc.client_secret")
@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) {
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
DisableUpdateCheck: false,
PrefixV4: prefix4,
PrefixV6: prefix6,
@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout",
),
Database: GetDatabaseConfig(),
Database: databaseConfig(),
TLS: GetTLSConfig(),
TLS: tlsConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSConfig: dnsToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
ACMEEmail: viper.GetString("acme_email"),
@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) {
LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort,
Policy: GetPolicyConfig(),
Policy: policyConfig(),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]