create DB struct

This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-05-11 09:09:18 +02:00 committed by Kristoffer Dalby
parent b01f1f1867
commit 14e29a7bee
48 changed files with 1731 additions and 1572 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
@ -21,16 +22,22 @@ import (
const (
randomByteSize = 16
)
errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback")
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error(
var (
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
errOIDCAllowedDomains = errors.New(
"authenticated principal does not match any allowed domain",
)
errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group")
errOIDCAllowedUsers = errors.New(
"authenticated principal does not match any allowed user",
)
errOIDCInvalidMachineState = errors.New(
"requested machine state key expired before authorisation completed",
)
errOIDCNodeKeyMissing = Error("could not get node key from cache")
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)
type IDTokenClaims struct {
@ -94,7 +101,7 @@ func (h *Headscale) RegisterOIDC(
Bool("ok", ok).
Msg("Received oidc register call")
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -115,7 +122,7 @@ func (h *Headscale) RegisterOIDC(
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
)
if !ok || nodeKeyStr == "" || err != nil {
@ -149,7 +156,11 @@ func (h *Headscale) RegisterOIDC(
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
h.registrationCache.Set(
stateStr,
util.NodePublicKeyStripPrefix(nodeKey),
registerCacheExpiration,
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@ -406,7 +417,7 @@ func validateOIDCAllowedDomains(
) error {
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
!util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
@ -436,7 +447,7 @@ func validateOIDCAllowedGroups(
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
if IsStringInSlice(claims.Groups, group) {
if util.IsStringInSlice(claims.Groups, group) {
return nil
}
}
@ -466,7 +477,7 @@ func validateOIDCAllowedUsers(
claims *IDTokenClaims,
) error {
if len(allowedUsers) > 0 &&
!IsStringInSlice(allowedUsers, claims.Email) {
!util.IsStringInSlice(allowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
@ -531,7 +542,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
}
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
)
if err != nil {
log.Error().
@ -555,7 +566,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKey(nodeKey)
machine, _ := h.db.GetMachineByNodeKey(nodeKey)
if machine != nil {
log.Trace().
@ -563,7 +574,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
Str("machine", machine.Hostname).
Msg("machine already registered, reauthenticating")
err := h.RefreshMachine(machine, expiry)
err := h.db.RefreshMachine(machine, expiry)
if err != nil {
log.Error().
Caller().
@ -653,9 +664,9 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter,
userName string,
) (*User, error) {
user, err := h.GetUser(userName)
user, err := h.db.GetUser(userName)
if errors.Is(err, ErrUserNotFound) {
user, err = h.CreateUser(userName)
user, err = h.db.CreateUser(userName)
if err != nil {
log.Error().
@ -702,7 +713,9 @@ func (h *Headscale) registerMachineForOIDCCallback(
nodeKey *key.NodePublic,
expiry time.Time,
) error {
if _, err := h.RegisterMachineFromAuthCallback(
if _, err := h.db.RegisterMachineFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
h.registrationCache,
nodeKey.String(),
user.Name,
&expiry,