From ba168d0c33f0717c086cdf729dfd8e37267befb2 Mon Sep 17 00:00:00 2001 From: eyjhb Date: Tue, 3 Jun 2025 16:33:27 +0200 Subject: [PATCH] add ability to use context with serve --- hscontrol/app.go | 213 +++++++++++++++++++++++++---------------------- 1 file changed, 112 insertions(+), 101 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index d62acb34..870d307b 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -555,7 +555,7 @@ func nodesChangedHook( } // Serve launches the HTTP and gRPC server service Headscale and the API. -func (h *Headscale) Serve() error { +func (h *Headscale) Serve(ctx context.Context) error { capver.CanOldCodeBeCleanedUp() if profilingEnabled { @@ -631,7 +631,7 @@ func (h *Headscale) Serve() error { // Start all scheduled tasks, e.g. expiring nodes, derp updates and // records updates - scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) + scheduleCtx, scheduleCancel := context.WithCancel(ctx) defer scheduleCancel() go h.scheduledTasks(scheduleCtx) @@ -644,7 +644,6 @@ func (h *Headscale) Serve() error { // Prepare group for running listeners errorGroup := new(errgroup.Group) - ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -837,112 +836,124 @@ func (h *Headscale) Serve() error { sigFunc := func(c chan os.Signal) { // Wait for a SIGINT or SIGKILL: for { - sig := <-c - switch sig { - case syscall.SIGHUP: - log.Info(). - Str("signal", sig.String()). - Msg("Received SIGHUP, reloading ACL and Config") - - if h.cfg.Policy.IsEmpty() { - continue - } - - if err := h.loadPolicyManager(); err != nil { - log.Error().Err(err).Msg("failed to reload Policy") - } - - pol, err := h.policyBytes() - if err != nil { - log.Error().Err(err).Msg("failed to get policy blob") - } - - changed, err := h.polMan.SetPolicy(pol) - if err != nil { - log.Error().Err(err).Msg("failed to set new policy") - } - - if changed { + select { + case sig := <-c: + switch sig { + case syscall.SIGHUP: log.Info(). - Msg("ACL policy successfully reloaded, notifying nodes of change") + Str("signal", sig.String()). + Msg("Received SIGHUP, reloading ACL and Config") - err = h.autoApproveNodes() - if err != nil { - log.Error().Err(err).Msg("failed to approve routes after new policy") + if h.cfg.Policy.IsEmpty() { + continue } - ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") - h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + if err := h.loadPolicyManager(); err != nil { + log.Error().Err(err).Msg("failed to reload Policy") + } + + pol, err := h.policyBytes() + if err != nil { + log.Error().Err(err).Msg("failed to get policy blob") + } + + changed, err := h.polMan.SetPolicy(pol) + if err != nil { + log.Error().Err(err).Msg("failed to set new policy") + } + + if changed { + log.Info(). + Msg("ACL policy successfully reloaded, notifying nodes of change") + + err = h.autoApproveNodes() + if err != nil { + log.Error().Err(err).Msg("failed to approve routes after new policy") + } + + ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") + h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + } + default: + info := func(msg string) { log.Info().Msg(msg) } + log.Info(). + Str("signal", sig.String()). + Msg("Received signal to stop, shutting down gracefully") + + scheduleCancel() + h.ephemeralGC.Close() + + // Gracefully shut down servers + ctx, cancel := context.WithTimeout( + context.Background(), + types.HTTPShutdownTimeout, + ) + info("shutting down debug http server") + if err := debugHTTPServer.Shutdown(ctx); err != nil { + log.Error().Err(err).Msg("failed to shutdown prometheus http") + } + info("shutting down main http server") + if err := httpServer.Shutdown(ctx); err != nil { + log.Error().Err(err).Msg("failed to shutdown http") + } + + info("closing node notifier") + h.nodeNotifier.Close() + + info("waiting for netmap stream to close") + h.pollNetMapStreamWG.Wait() + + info("shutting down grpc server (socket)") + grpcSocket.GracefulStop() + + if grpcServer != nil { + info("shutting down grpc server (external)") + grpcServer.GracefulStop() + grpcListener.Close() + } + + if tailsqlContext != nil { + info("shutting down tailsql") + tailsqlContext.Done() + } + + // Close network listeners + info("closing network listeners") + debugHTTPListener.Close() + httpListener.Close() + grpcGatewayConn.Close() + + // Stop listening (and unlink the socket if unix type): + info("closing socket listener") + socketListener.Close() + + // Close db connections + info("closing database connection") + err = h.db.Close() + if err != nil { + log.Error().Err(err).Msg("failed to close db") + } + + log.Info(). + Msg("Headscale stopped") + + // And we're done: + cancel() + + return } - default: - info := func(msg string) { log.Info().Msg(msg) } - log.Info(). - Str("signal", sig.String()). - Msg("Received signal to stop, shutting down gracefully") - - scheduleCancel() - h.ephemeralGC.Close() - - // Gracefully shut down servers - ctx, cancel := context.WithTimeout( - context.Background(), - types.HTTPShutdownTimeout, - ) - info("shutting down debug http server") - if err := debugHTTPServer.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("failed to shutdown prometheus http") + case <-ctx.Done(): + // send signal to kill + // could be done a lot better + select { + case sigc <- os.Kill: + fmt.Println("sent kill message") + default: + fmt.Println("no kill message sent") } - info("shutting down main http server") - if err := httpServer.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("failed to shutdown http") - } - - info("closing node notifier") - h.nodeNotifier.Close() - - info("waiting for netmap stream to close") - h.pollNetMapStreamWG.Wait() - - info("shutting down grpc server (socket)") - grpcSocket.GracefulStop() - - if grpcServer != nil { - info("shutting down grpc server (external)") - grpcServer.GracefulStop() - grpcListener.Close() - } - - if tailsqlContext != nil { - info("shutting down tailsql") - tailsqlContext.Done() - } - - // Close network listeners - info("closing network listeners") - debugHTTPListener.Close() - httpListener.Close() - grpcGatewayConn.Close() - - // Stop listening (and unlink the socket if unix type): - info("closing socket listener") - socketListener.Close() - - // Close db connections - info("closing database connection") - err = h.db.Close() - if err != nil { - log.Error().Err(err).Msg("failed to close db") - } - - log.Info(). - Msg("Headscale stopped") - - // And we're done: - cancel() - - return } } + } errorGroup.Go(func() error { sigFunc(sigc)