create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/juanfont/headscale"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
@ -41,24 +42,21 @@ import (
|
|||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/reflection"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
errSTUNAddressNotSet = Error("STUN address not set")
|
||||
errUnsupportedDatabase = Error("unsupported DB")
|
||||
errUnsupportedLetsEncryptChallengeType = Error(
|
||||
var (
|
||||
errSTUNAddressNotSet = errors.New("STUN address not set")
|
||||
errUnsupportedDatabase = errors.New("unsupported DB")
|
||||
errUnsupportedLetsEncryptChallengeType = errors.New(
|
||||
"unknown value for Lets Encrypt challenge type",
|
||||
)
|
||||
)
|
||||
|
||||
const (
|
||||
AuthPrefix = "Bearer "
|
||||
Postgres = "postgres"
|
||||
Sqlite = "sqlite3"
|
||||
updateInterval = 5000
|
||||
HTTPReadTimeout = 30 * time.Second
|
||||
HTTPShutdownTimeout = 3 * time.Second
|
||||
|
@ -75,7 +73,7 @@ const (
|
|||
// Headscale represents the base app of the service.
|
||||
type Headscale struct {
|
||||
cfg *Config
|
||||
db *gorm.DB
|
||||
db *HSDatabase
|
||||
dbString string
|
||||
dbType string
|
||||
dbDebug bool
|
||||
|
@ -96,10 +94,11 @@ type Headscale struct {
|
|||
|
||||
registrationCache *cache.Cache
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
shutdownChan chan struct{}
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
|
||||
stateUpdateChan chan struct{}
|
||||
cancelStateUpdateChan chan struct{}
|
||||
}
|
||||
|
||||
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||
|
@ -164,13 +163,27 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
|||
registrationCache: registrationCache,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
lastStateChange: xsync.NewMapOf[time.Time](),
|
||||
|
||||
stateUpdateChan: make(chan struct{}),
|
||||
cancelStateUpdateChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = app.initDB()
|
||||
go app.watchStateChannel()
|
||||
|
||||
db, err := NewHeadscaleDatabase(
|
||||
cfg.DBtype,
|
||||
dbString,
|
||||
cfg.OIDC.StripEmaildomain,
|
||||
app.dbDebug,
|
||||
app.stateUpdateChan,
|
||||
cfg.IPPrefixes,
|
||||
cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app.db = db
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = app.initOIDC()
|
||||
if err != nil {
|
||||
|
@ -231,7 +244,7 @@ func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
|||
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
for range ticker.C {
|
||||
err := h.handlePrimarySubnetFailover()
|
||||
err := h.db.handlePrimarySubnetFailover()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
||||
}
|
||||
|
@ -239,7 +252,7 @@ func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
|||
}
|
||||
|
||||
func (h *Headscale) expireEphemeralNodesWorker() {
|
||||
users, err := h.ListUsers()
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
|
@ -247,7 +260,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
|||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := h.ListMachinesByUser(user.Name)
|
||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -267,7 +280,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
|||
Str("machine", machine.Hostname).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = h.db.Unscoped().Delete(machine).Error
|
||||
err = h.db.db.Unscoped().Delete(machine).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -284,7 +297,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
|||
}
|
||||
|
||||
func (h *Headscale) expireExpiredMachinesWorker() {
|
||||
users, err := h.ListUsers()
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
|
@ -292,7 +305,7 @@ func (h *Headscale) expireExpiredMachinesWorker() {
|
|||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := h.ListMachinesByUser(user.Name)
|
||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -308,7 +321,7 @@ func (h *Headscale) expireExpiredMachinesWorker() {
|
|||
machine.Expiry.After(h.getLastStateChange(user)) {
|
||||
expiredFound = true
|
||||
|
||||
err := h.ExpireMachine(&machines[index])
|
||||
err := h.db.ExpireMachine(&machines[index])
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -387,7 +400,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
)
|
||||
}
|
||||
|
||||
valid, err := h.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
||||
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -438,7 +451,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||
return
|
||||
}
|
||||
|
||||
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
||||
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -597,7 +610,7 @@ func (h *Headscale) Serve() error {
|
|||
h.cfg.UnixSocket,
|
||||
[]grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(GrpcSocketDialer),
|
||||
grpc.WithContextDialer(util.GrpcSocketDialer),
|
||||
}...,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -760,7 +773,7 @@ func (h *Headscale) Serve() error {
|
|||
// TODO(kradalby): Reload config on SIGHUP
|
||||
|
||||
if h.cfg.ACL.PolicyPath != "" {
|
||||
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
||||
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
||||
err := h.LoadACLPolicyFromPath(aclPath)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
||||
|
@ -778,6 +791,7 @@ func (h *Headscale) Serve() error {
|
|||
Msg("Received signal to stop, shutting down gracefully")
|
||||
|
||||
close(h.shutdownChan)
|
||||
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
|
||||
// Gracefully shut down servers
|
||||
|
@ -806,8 +820,12 @@ func (h *Headscale) Serve() error {
|
|||
// Stop listening (and unlink the socket if unix type):
|
||||
socketListener.Close()
|
||||
|
||||
<-h.cancelStateUpdateChan
|
||||
close(h.stateUpdateChan)
|
||||
close(h.cancelStateUpdateChan)
|
||||
|
||||
// Close db connections
|
||||
db, err := h.db.DB()
|
||||
db, err := h.db.db.DB()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get db handle")
|
||||
}
|
||||
|
@ -905,12 +923,25 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): baby steps, make this more robust.
|
||||
func (h *Headscale) watchStateChannel() {
|
||||
for {
|
||||
select {
|
||||
case <-h.stateUpdateChan:
|
||||
h.setLastStateChangeToNow()
|
||||
|
||||
case <-h.cancelStateUpdateChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) setLastStateChangeToNow() {
|
||||
var err error
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
users, err := h.ListUsers()
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -1002,7 +1033,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
|||
}
|
||||
|
||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||
privateKeyEnsurePrefix := PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
||||
privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
||||
|
||||
var machineKey key.MachinePrivate
|
||||
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue