Only load needed part of configuration (#2109)
This commit is contained in:
parent
f368ed01ed
commit
8a3a0fee3c
18 changed files with 196 additions and 324 deletions
|
@ -212,6 +212,12 @@ type Tuning struct {
|
|||
NodeMapSessionBufferedChanSize int
|
||||
}
|
||||
|
||||
// LoadConfig prepares and loads the Headscale configuration into Viper.
|
||||
// This means it sets the default values, reads the configuration file and
|
||||
// environment variables, and handles deprecated configuration options.
|
||||
// It has to be called before LoadServerConfig and LoadCLIConfig.
|
||||
// The configuration is not validated and the caller should check for errors
|
||||
// using a validation function.
|
||||
func LoadConfig(path string, isFile bool) error {
|
||||
if isFile {
|
||||
viper.SetConfigFile(path)
|
||||
|
@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error {
|
|||
|
||||
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
|
||||
|
||||
if IsCLIConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateServerConfig() error {
|
||||
depr := deprecator{
|
||||
warns: make(set.Set[string]),
|
||||
fatals: make(set.Set[string]),
|
||||
|
@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error {
|
|||
if errorText != "" {
|
||||
// nolint
|
||||
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetTLSConfig() TLSConfig {
|
||||
func tlsConfig() TLSConfig {
|
||||
return TLSConfig{
|
||||
LetsEncrypt: LetsEncryptConfig{
|
||||
Hostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||
|
@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetDERPConfig() DERPConfig {
|
||||
func derpConfig() DERPConfig {
|
||||
serverEnabled := viper.GetBool("derp.server.enabled")
|
||||
serverRegionID := viper.GetInt("derp.server.region_id")
|
||||
serverRegionCode := viper.GetString("derp.server.region_code")
|
||||
|
@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetLogTailConfig() LogTailConfig {
|
||||
func logtailConfig() LogTailConfig {
|
||||
enabled := viper.GetBool("logtail.enabled")
|
||||
|
||||
return LogTailConfig{
|
||||
|
@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetPolicyConfig() PolicyConfig {
|
||||
func policyConfig() PolicyConfig {
|
||||
policyPath := viper.GetString("policy.path")
|
||||
policyMode := viper.GetString("policy.mode")
|
||||
|
||||
|
@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetLogConfig() LogConfig {
|
||||
func logConfig() LogConfig {
|
||||
logLevelStr := viper.GetString("log.level")
|
||||
logLevel, err := zerolog.ParseLevel(logLevelStr)
|
||||
if err != nil {
|
||||
|
@ -473,9 +479,9 @@ func GetLogConfig() LogConfig {
|
|||
logFormatOpt := viper.GetString("log.format")
|
||||
var logFormat string
|
||||
switch logFormatOpt {
|
||||
case "json":
|
||||
case JSONLogFormat:
|
||||
logFormat = JSONLogFormat
|
||||
case "text":
|
||||
case TextLogFormat:
|
||||
logFormat = TextLogFormat
|
||||
case "":
|
||||
logFormat = TextLogFormat
|
||||
|
@ -491,7 +497,7 @@ func GetLogConfig() LogConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetDatabaseConfig() DatabaseConfig {
|
||||
func databaseConfig() DatabaseConfig {
|
||||
debug := viper.GetBool("database.debug")
|
||||
|
||||
type_ := viper.GetString("database.type")
|
||||
|
@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func DNS() (DNSConfig, error) {
|
||||
func dns() (DNSConfig, error) {
|
||||
var dns DNSConfig
|
||||
|
||||
// TODO: Use this instead of manually getting settings when
|
||||
|
@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) {
|
|||
return dns, nil
|
||||
}
|
||||
|
||||
// GlobalResolvers returns the global DNS resolvers
|
||||
// globalResolvers returns the global DNS resolvers
|
||||
// defined in the config file.
|
||||
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
||||
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
||||
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
||||
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
|
||||
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
||||
for _, nsStr := range d.Nameservers.Global {
|
||||
|
@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
|
|||
return resolvers
|
||||
}
|
||||
|
||||
// SplitResolvers returns a map of domain to DNS resolvers.
|
||||
// splitResolvers returns a map of domain to DNS resolvers.
|
||||
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
||||
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
||||
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
||||
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
|
||||
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
|
||||
routes := make(map[string][]*dnstype.Resolver)
|
||||
for domain, nameservers := range d.Nameservers.Split {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
|
|||
return routes
|
||||
}
|
||||
|
||||
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
cfg := tailcfg.DNSConfig{}
|
||||
|
||||
if dns.BaseDomain == "" && dns.MagicDNS {
|
||||
|
@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
|||
|
||||
cfg.Proxied = dns.MagicDNS
|
||||
cfg.ExtraRecords = dns.ExtraRecords
|
||||
cfg.Resolvers = dns.GlobalResolvers()
|
||||
cfg.Resolvers = dns.globalResolvers()
|
||||
|
||||
routes := dns.SplitResolvers()
|
||||
routes := dns.splitResolvers()
|
||||
cfg.Routes = routes
|
||||
if dns.BaseDomain != "" {
|
||||
cfg.Domains = []string{dns.BaseDomain}
|
||||
|
@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
|||
return &cfg
|
||||
}
|
||||
|
||||
func PrefixV4() (*netip.Prefix, error) {
|
||||
func prefixV4() (*netip.Prefix, error) {
|
||||
prefixV4Str := viper.GetString("prefixes.v4")
|
||||
|
||||
if prefixV4Str == "" {
|
||||
|
@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) {
|
|||
return &prefixV4, nil
|
||||
}
|
||||
|
||||
func PrefixV6() (*netip.Prefix, error) {
|
||||
func prefixV6() (*netip.Prefix, error) {
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
if prefixV6Str == "" {
|
||||
|
@ -723,27 +729,37 @@ func PrefixV6() (*netip.Prefix, error) {
|
|||
return &prefixV6, nil
|
||||
}
|
||||
|
||||
func GetHeadscaleConfig() (*Config, error) {
|
||||
if IsCLIConfigured() {
|
||||
return &Config{
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
APIKey: viper.GetString("cli.api_key"),
|
||||
Timeout: viper.GetDuration("cli.timeout"),
|
||||
Insecure: viper.GetBool("cli.insecure"),
|
||||
},
|
||||
}, nil
|
||||
// LoadCLIConfig returns the needed configuration for the CLI client
|
||||
// of Headscale to connect to a Headscale server.
|
||||
func LoadCLIConfig() (*Config, error) {
|
||||
return &Config{
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
UnixSocket: viper.GetString("unix_socket"),
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
APIKey: viper.GetString("cli.api_key"),
|
||||
Timeout: viper.GetDuration("cli.timeout"),
|
||||
Insecure: viper.GetBool("cli.insecure"),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadServerConfig returns the full Headscale configuration to
|
||||
// host a Headscale server. This is called as part of `headscale serve`.
|
||||
func LoadServerConfig() (*Config, error) {
|
||||
if err := validateServerConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logConfig := GetLogConfig()
|
||||
logConfig := logConfig()
|
||||
zerolog.SetGlobalLevel(logConfig.Level)
|
||||
|
||||
prefix4, err := PrefixV4()
|
||||
prefix4, err := prefixV4()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix6, err := PrefixV6()
|
||||
prefix6, err := prefixV6()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
|
||||
}
|
||||
|
||||
dnsConfig, err := DNS()
|
||||
dnsConfig, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
derpConfig := GetDERPConfig()
|
||||
logTailConfig := GetLogTailConfig()
|
||||
derpConfig := derpConfig()
|
||||
logTailConfig := logtailConfig()
|
||||
randomizeClientPort := viper.GetBool("randomize_client_port")
|
||||
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
|
@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
MetricsAddr: viper.GetString("metrics_listen_addr"),
|
||||
GRPCAddr: viper.GetString("grpc_listen_addr"),
|
||||
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
DisableUpdateCheck: false,
|
||||
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
|
@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
"ephemeral_node_inactivity_timeout",
|
||||
),
|
||||
|
||||
Database: GetDatabaseConfig(),
|
||||
Database: databaseConfig(),
|
||||
|
||||
TLS: GetTLSConfig(),
|
||||
TLS: tlsConfig(),
|
||||
|
||||
DNSConfig: DNSToTailcfgDNS(dnsConfig),
|
||||
DNSConfig: dnsToTailcfgDNS(dnsConfig),
|
||||
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
|
||||
|
||||
ACMEEmail: viper.GetString("acme_email"),
|
||||
|
@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
LogTail: logTailConfig,
|
||||
RandomizeClientPort: randomizeClientPort,
|
||||
|
||||
Policy: GetPolicyConfig(),
|
||||
Policy: policyConfig(),
|
||||
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
|
@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func IsCLIConfigured() bool {
|
||||
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
|
||||
}
|
||||
|
||||
type deprecator struct {
|
||||
warns set.Set[string]
|
||||
fatals set.Set[string]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue