Split code into modules

This is a massive commit that restructures the code into modules:

db/
    All functions related to modifying the Database
types/
    All type definitions and methods that can be exclusively used on these types without dependencies
policy/
    All Policy related code, now without dependencies on the Database
policy/matcher/
    Dedicated code to match machines in a list of FilterRules

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
parent 14e29a7bee
commit feb15365b5

51 changed files with 4677 additions and 4290 deletions
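The practical effect of the split is visible in hscontrol/app.go below: the app now imports hscontrol/db, hscontrol/policy and hscontrol/types and calls into them instead of package-local helpers. As a rough, non-authoritative sketch of how the new packages compose, using only the call names and return shapes that appear in this diff (anything beyond them is an assumption):

// sketch.go - illustrative only, not part of the commit.
package sketch

import (
    "tailscale.com/tailcfg"

    "github.com/juanfont/headscale/hscontrol/db"
    "github.com/juanfont/headscale/hscontrol/policy"
)

// reloadFilterRules mirrors the flow shown in the watchPolicyChannel hunk
// below: the db package lists machines, and the policy package turns the
// ACL policy into tailcfg filter rules plus an SSH policy.
func reloadFilterRules(
    database *db.HSDatabase,
    pol *policy.ACLPolicy,
    stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
    machines, err := database.ListMachines()
    if err != nil {
        return nil, nil, err
    }

    return policy.GenerateFilterRules(pol, machines, stripEmailDomain)
}

The design keeps app.go as the composition point: db/ owns persistence, policy/ owns ACL evaluation, and the two only meet in the Headscale struct.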
hscontrol/app.go (166 changed lines)
@@ -23,6 +23,9 @@ import (
     "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
     "github.com/juanfont/headscale"
     v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
+    "github.com/juanfont/headscale/hscontrol/db"
+    "github.com/juanfont/headscale/hscontrol/policy"
+    "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/patrickmn/go-cache"
     zerolog "github.com/philip-bui/grpc-zerolog"
@@ -73,7 +76,7 @@ const (
 // Headscale represents the base app of the service.
 type Headscale struct {
     cfg *Config
-    db  *HSDatabase
+    db  *db.HSDatabase
     dbString string
     dbType   string
     dbDebug  bool
@@ -83,7 +86,7 @@ type Headscale struct {
     DERPMap    *tailcfg.DERPMap
     DERPServer *DERPServer

-    aclPolicy *ACLPolicy
+    ACLPolicy *policy.ACLPolicy
     aclRules  []tailcfg.FilterRule
     sshPolicy *tailcfg.SSHPolicy

@@ -99,6 +102,12 @@ type Headscale struct {

     stateUpdateChan       chan struct{}
     cancelStateUpdateChan chan struct{}
+
+    // TODO(kradalby): Temporary measure to make sure we can update policy
+    // across modules, will be removed when aclRules are no longer stored
+    // globally but generated per node basis.
+    policyUpdateChan       chan struct{}
+    cancelPolicyUpdateChan chan struct{}
 }

 func NewHeadscale(cfg *Config) (*Headscale, error) {
@@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {

     var dbString string
     switch cfg.DBtype {
-    case Postgres:
+    case db.Postgres:
         dbString = fmt.Sprintf(
             "host=%s dbname=%s user=%s",
             cfg.DBhost,
@@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
         if cfg.DBpass != "" {
             dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
         }
-    case Sqlite:
+    case db.Sqlite:
         dbString = cfg.DBpath
     default:
         return nil, errUnsupportedDatabase
@@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {

         stateUpdateChan:       make(chan struct{}),
         cancelStateUpdateChan: make(chan struct{}),
+
+        policyUpdateChan:       make(chan struct{}),
+        cancelPolicyUpdateChan: make(chan struct{}),
     }

     go app.watchStateChannel()
+    go app.watchPolicyChannel()

-    db, err := NewHeadscaleDatabase(
+    database, err := db.NewHeadscaleDatabase(
         cfg.DBtype,
         dbString,
         cfg.OIDC.StripEmaildomain,
         app.dbDebug,
         app.stateUpdateChan,
+        app.policyUpdateChan,
         cfg.IPPrefixes,
         cfg.BaseDomain)
     if err != nil {
         return nil, err
     }

-    app.db = db
+    app.db = database

     if cfg.OIDC.Issuer != "" {
         err = app.initOIDC()
@@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
 func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
     ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
     for range ticker.C {
-        h.expireEphemeralNodesWorker()
+        h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout)
     }
 }

@@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
 func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
     ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
     for range ticker.C {
-        h.expireExpiredMachinesWorker()
+        h.db.ExpireExpiredMachines(h.getLastStateChange())
     }
 }

 func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
     ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
     for range ticker.C {
-        err := h.db.handlePrimarySubnetFailover()
+        err := h.db.HandlePrimarySubnetFailover()
         if err != nil {
             log.Error().Err(err).Msg("failed to handle primary subnet failover")
         }
     }
 }

-func (h *Headscale) expireEphemeralNodesWorker() {
-    users, err := h.db.ListUsers()
-    if err != nil {
-        log.Error().Err(err).Msg("Error listing users")
-
-        return
-    }
-
-    for _, user := range users {
-        machines, err := h.db.ListMachinesByUser(user.Name)
-        if err != nil {
-            log.Error().
-                Err(err).
-                Str("user", user.Name).
-                Msg("Error listing machines in user")
-
-            return
-        }
-
-        expiredFound := false
-        for _, machine := range machines {
-            if machine.isEphemeral() && machine.LastSeen != nil &&
-                time.Now().
-                    After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
-                expiredFound = true
-                log.Info().
-                    Str("machine", machine.Hostname).
-                    Msg("Ephemeral client removed from database")
-
-                err = h.db.db.Unscoped().Delete(machine).Error
-                if err != nil {
-                    log.Error().
-                        Err(err).
-                        Str("machine", machine.Hostname).
-                        Msg("🤮 Cannot delete ephemeral machine from the database")
-                }
-            }
-        }
-
-        if expiredFound {
-            h.setLastStateChangeToNow()
-        }
-    }
-}
-
-func (h *Headscale) expireExpiredMachinesWorker() {
-    users, err := h.db.ListUsers()
-    if err != nil {
-        log.Error().Err(err).Msg("Error listing users")
-
-        return
-    }
-
-    for _, user := range users {
-        machines, err := h.db.ListMachinesByUser(user.Name)
-        if err != nil {
-            log.Error().
-                Err(err).
-                Str("user", user.Name).
-                Msg("Error listing machines in user")
-
-            return
-        }
-
-        expiredFound := false
-        for index, machine := range machines {
-            if machine.isExpired() &&
-                machine.Expiry.After(h.getLastStateChange(user)) {
-                expiredFound = true
-
-                err := h.db.ExpireMachine(&machines[index])
-                if err != nil {
-                    log.Error().
-                        Err(err).
-                        Str("machine", machine.Hostname).
-                        Str("name", machine.GivenName).
-                        Msg("🤮 Cannot expire machine")
-                } else {
-                    log.Info().
-                        Str("machine", machine.Hostname).
-                        Str("name", machine.GivenName).
-                        Msg("Machine successfully expired")
-                }
-            }
-        }
-
-        if expiredFound {
-            h.setLastStateChangeToNow()
-        }
-    }
-}
-
 func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
     req interface{},
     info *grpc.UnaryServerInfo,
@@ -565,6 +487,8 @@ func (h *Headscale) Serve() error {
         go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
     }

+    // TODO(kradalby): These should have cancel channels and be cleaned
+    // up on shutdown.
     go h.expireEphemeralNodes(updateInterval)
     go h.expireExpiredMachines(updateInterval)

@@ -774,10 +698,12 @@ func (h *Headscale) Serve() error {

     if h.cfg.ACL.PolicyPath != "" {
         aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
-        err := h.LoadACLPolicyFromPath(aclPath)
+        pol, err := policy.LoadACLPolicyFromPath(aclPath)
         if err != nil {
             log.Error().Err(err).Msg("Failed to reload ACL policy")
         }
+
+        h.ACLPolicy = pol
         log.Info().
             Str("path", aclPath).
             Msg("ACL policy successfully reloaded, notifying nodes of change")
@@ -824,12 +750,12 @@ func (h *Headscale) Serve() error {
     close(h.stateUpdateChan)
     close(h.cancelStateUpdateChan)
+
+    <-h.cancelPolicyUpdateChan
+    close(h.policyUpdateChan)
+    close(h.cancelPolicyUpdateChan)

     // Close db connections
-    db, err := h.db.db.DB()
-    if err != nil {
-        log.Error().Err(err).Msg("Failed to get db handle")
-    }
-    err = db.Close()
+    err = h.db.Close()
     if err != nil {
         log.Error().Err(err).Msg("Failed to close db")
     }
@@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() {
     }
 }

+// TODO(kradalby): baby steps, make this more robust.
+func (h *Headscale) watchPolicyChannel() {
+    for {
+        select {
+        case <-h.policyUpdateChan:
+            machines, err := h.db.ListMachines()
+            if err != nil {
+                log.Error().Err(err).Msg("failed to fetch machines during policy update")
+            }
+
+            rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
+            if err != nil {
+                log.Error().Err(err).Msg("failed to update ACL rules")
+            }
+
+            h.aclRules = rules
+            h.sshPolicy = sshPolicy
+
+        case <-h.cancelPolicyUpdateChan:
+            return
+        }
+    }
+}
+
 func (h *Headscale) setLastStateChangeToNow() {
     var err error

@@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() {
     }
 }

-func (h *Headscale) getLastStateChange(users ...User) time.Time {
+func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
     times := []time.Time{}

     // getLastStateChange takes a list of users as a "filter", if no users
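The policyUpdateChan plumbing added above (the TODO in the Headscale struct and the watchPolicyChannel hunk) is an explicitly temporary cross-module signal: other modules notify the channel, and app.go rebuilds the globally stored aclRules and sshPolicy. A minimal, generic sketch of that watcher pattern follows, with hypothetical stand-in types rather than the headscale API; only the select-over-update/cancel-channel shape mirrors the diff.

package sketch

// Rules and the generate/apply callbacks are hypothetical stand-ins used
// purely to illustrate the channel-driven refresh loop.
type Rules []string

func watch(update, cancel <-chan struct{}, generate func() (Rules, error), apply func(Rules)) {
    for {
        select {
        case <-update:
            // Recompute the derived state whenever another module signals a change.
            rules, err := generate()
            if err != nil {
                continue // keep the previous rules on failure
            }
            apply(rules)
        case <-cancel:
            return
        }
    }
}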