Split code into modules

This is a massive commit that restructures the code into modules:

db/
    All functions related to modifying the Database

types/
    All type definitions and methods that can be exclusivly used on
    these types without dependencies

policy/
    All Policy related code, now without dependencies on the Database.

policy/matcher/
    Dedicated code to match machines in a list of FilterRules

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-21 19:37:59 +03:00 committed by Kristoffer Dalby
parent 14e29a7bee
commit feb15365b5
51 changed files with 4677 additions and 4290 deletions

View file

@ -23,6 +23,9 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
@ -73,7 +76,7 @@ const (
// Headscale represents the base app of the service.
type Headscale struct {
cfg *Config
db *HSDatabase
db *db.HSDatabase
dbString string
dbType string
dbDebug bool
@ -83,7 +86,7 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap
DERPServer *DERPServer
aclPolicy *ACLPolicy
ACLPolicy *policy.ACLPolicy
aclRules []tailcfg.FilterRule
sshPolicy *tailcfg.SSHPolicy
@ -99,6 +102,12 @@ type Headscale struct {
stateUpdateChan chan struct{}
cancelStateUpdateChan chan struct{}
// TODO(kradalby): Temporary measure to make sure we can update policy
// across modules, will be removed when aclRules are no longer stored
// globally but generated per node basis.
policyUpdateChan chan struct{}
cancelPolicyUpdateChan chan struct{}
}
func NewHeadscale(cfg *Config) (*Headscale, error) {
@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
var dbString string
switch cfg.DBtype {
case Postgres:
case db.Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.DBhost,
@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
}
case Sqlite:
case db.Sqlite:
dbString = cfg.DBpath
default:
return nil, errUnsupportedDatabase
@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
stateUpdateChan: make(chan struct{}),
cancelStateUpdateChan: make(chan struct{}),
policyUpdateChan: make(chan struct{}),
cancelPolicyUpdateChan: make(chan struct{}),
}
go app.watchStateChannel()
go app.watchPolicyChannel()
db, err := NewHeadscaleDatabase(
database, err := db.NewHeadscaleDatabase(
cfg.DBtype,
dbString,
cfg.OIDC.StripEmaildomain,
app.dbDebug,
app.stateUpdateChan,
app.policyUpdateChan,
cfg.IPPrefixes,
cfg.BaseDomain)
if err != nil {
return nil, err
}
app.db = db
app.db = database
if cfg.OIDC.Issuer != "" {
err = app.initOIDC()
@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
h.expireEphemeralNodesWorker()
h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout)
}
}
@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
h.expireExpiredMachinesWorker()
h.db.ExpireExpiredMachines(h.getLastStateChange())
}
}
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
err := h.db.handlePrimarySubnetFailover()
err := h.db.HandlePrimarySubnetFailover()
if err != nil {
log.Error().Err(err).Msg("failed to handle primary subnet failover")
}
}
}
func (h *Headscale) expireEphemeralNodesWorker() {
users, err := h.db.ListUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
}
for _, user := range users {
machines, err := h.db.ListMachinesByUser(user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing machines in user")
return
}
expiredFound := false
for _, machine := range machines {
if machine.isEphemeral() && machine.LastSeen != nil &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
expiredFound = true
log.Info().
Str("machine", machine.Hostname).
Msg("Ephemeral client removed from database")
err = h.db.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
if expiredFound {
h.setLastStateChangeToNow()
}
}
}
func (h *Headscale) expireExpiredMachinesWorker() {
users, err := h.db.ListUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
}
for _, user := range users {
machines, err := h.db.ListMachinesByUser(user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing machines in user")
return
}
expiredFound := false
for index, machine := range machines {
if machine.isExpired() &&
machine.Expiry.After(h.getLastStateChange(user)) {
expiredFound = true
err := h.db.ExpireMachine(&machines[index])
if err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Str("name", machine.GivenName).
Msg("🤮 Cannot expire machine")
} else {
log.Info().
Str("machine", machine.Hostname).
Str("name", machine.GivenName).
Msg("Machine successfully expired")
}
}
}
if expiredFound {
h.setLastStateChangeToNow()
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
@ -565,6 +487,8 @@ func (h *Headscale) Serve() error {
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
}
// TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown.
go h.expireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval)
@ -774,10 +698,12 @@ func (h *Headscale) Serve() error {
if h.cfg.ACL.PolicyPath != "" {
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
err := h.LoadACLPolicyFromPath(aclPath)
pol, err := policy.LoadACLPolicyFromPath(aclPath)
if err != nil {
log.Error().Err(err).Msg("Failed to reload ACL policy")
}
h.ACLPolicy = pol
log.Info().
Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change")
@ -824,12 +750,12 @@ func (h *Headscale) Serve() error {
close(h.stateUpdateChan)
close(h.cancelStateUpdateChan)
<-h.cancelPolicyUpdateChan
close(h.policyUpdateChan)
close(h.cancelPolicyUpdateChan)
// Close db connections
db, err := h.db.db.DB()
if err != nil {
log.Error().Err(err).Msg("Failed to get db handle")
}
err = db.Close()
err = h.db.Close()
if err != nil {
log.Error().Err(err).Msg("Failed to close db")
}
@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() {
}
}
// TODO(kradalby): baby steps, make this more robust.
func (h *Headscale) watchPolicyChannel() {
for {
select {
case <-h.policyUpdateChan:
machines, err := h.db.ListMachines()
if err != nil {
log.Error().Err(err).Msg("failed to fetch machines during policy update")
}
rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().Err(err).Msg("failed to update ACL rules")
}
h.aclRules = rules
h.sshPolicy = sshPolicy
case <-h.cancelPolicyUpdateChan:
return
}
}
}
func (h *Headscale) setLastStateChangeToNow() {
var err error
@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() {
}
}
func (h *Headscale) getLastStateChange(users ...User) time.Time {
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
times := []time.Time{}
// getLastStateChange takes a list of users as a "filter", if no users