Switch to use gorilla's mux as muxer

This commit is contained in:
Juan Font Alonso 2022-06-18 18:41:42 +02:00
parent d5e331a2fb
commit d89fb68a7a
4 changed files with 92 additions and 67 deletions

137
app.go
View file

@ -18,6 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin"
"github.com/gorilla/mux"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -326,48 +327,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
return handler(ctx, req)
}
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace().
Caller().
Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked")
authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(
w http.ResponseWriter,
r *http.Request,
) {
log.Trace().
Caller().
Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
ctx.AbortWithStatus(http.StatusUnauthorized)
Str("client_address", r.RemoteAddr).
Msg("HTTP authentication invoked")
return
}
authHeader := r.Header.Get("X-Session-Token")
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", ctx.ClientIP()).
Msg("failed to validate token")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
Str("client_address", r.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorized"))
ctx.AbortWithStatus(http.StatusInternalServerError)
return
}
return
}
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", r.RemoteAddr).
Msg("failed to validate token")
if !valid {
log.Info().
Str("client_address", ctx.ClientIP()).
Msg("invalid token")
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Unauthorized"))
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
return
}
if !valid {
log.Info().
Str("client_address", r.RemoteAddr).
Msg("invalid token")
ctx.Next()
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorized"))
return
}
next.ServeHTTP(w, r)
})
}
// ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear
@ -390,39 +399,42 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine {
return promRouter
}
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router := gin.Default()
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router := mux.NewRouter()
router.GET(
router.HandleFunc(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
router.GET("/key", gin.WrapF(h.KeyHandler))
router.GET("/register", gin.WrapF(h.RegisterWebAPI))
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", gin.WrapF(h.OIDCCallback))
router.GET("/apple", gin.WrapF(h.AppleConfigMessage))
router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig))
router.GET("/windows", gin.WrapF(h.WindowsConfigMessage))
router.GET("/windows/tailscale.reg", gin.WrapF(h.WindowsRegConfig))
router.GET("/swagger", gin.WrapF(SwaggerUI))
router.GET("/swagger/v1/openapiv2.json", gin.WrapF(SwaggerAPIv1))
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("{\"healthy\": \"ok\"}"))
}).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/:id/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/:id", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/:mkey", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/:platform", h.ApplePlatformConfig).Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet)
router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet)
if h.cfg.DERP.ServerEnabled {
router.Any("/derp", h.DERPHandler)
router.Any("/derp/probe", h.DERPProbeHandler)
router.Any("/bootstrap-dns", h.DERPBootstrapDNSHandler)
router.HandleFunc("/derp", h.DERPHandler)
router.HandleFunc("/derp/probe", h.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler)
}
api := router.Group("/api")
api := router.PathPrefix("/api").Subrouter()
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP))
api.HandleFunc("/v1/*any", grpcMux.ServeHTTP)
}
router.NoRoute(stdoutHandler)
router.PathPrefix("/").HandlerFunc(stdoutHandler)
return router
}
@ -811,13 +823,16 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
}
}
func stdoutHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
func stdoutHandler(
w http.ResponseWriter,
r *http.Request,
) {
body, _ := io.ReadAll(r.Body)
log.Trace().
Interface("header", ctx.Request.Header).
Interface("proto", ctx.Request.Proto).
Interface("url", ctx.Request.URL).
Interface("header", r.Header).
Interface("proto", r.Proto).
Interface("url", r.URL).
Bytes("body", body).
Msg("Request did not match")
}