create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/types/key"
|
||||
|
@ -21,16 +22,22 @@ import (
|
|||
|
||||
const (
|
||||
randomByteSize = 16
|
||||
)
|
||||
|
||||
errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
|
||||
errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback")
|
||||
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
|
||||
errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group")
|
||||
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
|
||||
errOIDCInvalidMachineState = Error(
|
||||
var (
|
||||
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
||||
errOIDCAllowedDomains = errors.New(
|
||||
"authenticated principal does not match any allowed domain",
|
||||
)
|
||||
errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group")
|
||||
errOIDCAllowedUsers = errors.New(
|
||||
"authenticated principal does not match any allowed user",
|
||||
)
|
||||
errOIDCInvalidMachineState = errors.New(
|
||||
"requested machine state key expired before authorisation completed",
|
||||
)
|
||||
errOIDCNodeKeyMissing = Error("could not get node key from cache")
|
||||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
|
@ -94,7 +101,7 @@ func (h *Headscale) RegisterOIDC(
|
|||
Bool("ok", ok).
|
||||
Msg("Received oidc register call")
|
||||
|
||||
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -115,7 +122,7 @@ func (h *Headscale) RegisterOIDC(
|
|||
// the template and log an error.
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||
)
|
||||
|
||||
if !ok || nodeKeyStr == "" || err != nil {
|
||||
|
@ -149,7 +156,11 @@ func (h *Headscale) RegisterOIDC(
|
|||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
|
||||
h.registrationCache.Set(
|
||||
stateStr,
|
||||
util.NodePublicKeyStripPrefix(nodeKey),
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
|
@ -406,7 +417,7 @@ func validateOIDCAllowedDomains(
|
|||
) error {
|
||||
if len(allowedDomains) > 0 {
|
||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||
!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
|
||||
!util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
|
||||
log.Error().Msg("authenticated principal does not match any allowed domain")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
|
@ -436,7 +447,7 @@ func validateOIDCAllowedGroups(
|
|||
) error {
|
||||
if len(allowedGroups) > 0 {
|
||||
for _, group := range allowedGroups {
|
||||
if IsStringInSlice(claims.Groups, group) {
|
||||
if util.IsStringInSlice(claims.Groups, group) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -466,7 +477,7 @@ func validateOIDCAllowedUsers(
|
|||
claims *IDTokenClaims,
|
||||
) error {
|
||||
if len(allowedUsers) > 0 &&
|
||||
!IsStringInSlice(allowedUsers, claims.Email) {
|
||||
!util.IsStringInSlice(allowedUsers, claims.Email) {
|
||||
log.Error().Msg("authenticated principal does not match any allowed user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
|
@ -531,7 +542,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
|||
}
|
||||
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -555,7 +566,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new machine and we will move
|
||||
// on to registration.
|
||||
machine, _ := h.GetMachineByNodeKey(nodeKey)
|
||||
machine, _ := h.db.GetMachineByNodeKey(nodeKey)
|
||||
|
||||
if machine != nil {
|
||||
log.Trace().
|
||||
|
@ -563,7 +574,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
|||
Str("machine", machine.Hostname).
|
||||
Msg("machine already registered, reauthenticating")
|
||||
|
||||
err := h.RefreshMachine(machine, expiry)
|
||||
err := h.db.RefreshMachine(machine, expiry)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -653,9 +664,9 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
|||
writer http.ResponseWriter,
|
||||
userName string,
|
||||
) (*User, error) {
|
||||
user, err := h.GetUser(userName)
|
||||
user, err := h.db.GetUser(userName)
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
user, err = h.CreateUser(userName)
|
||||
user, err = h.db.CreateUser(userName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -702,7 +713,9 @@ func (h *Headscale) registerMachineForOIDCCallback(
|
|||
nodeKey *key.NodePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
if _, err := h.RegisterMachineFromAuthCallback(
|
||||
if _, err := h.db.RegisterMachineFromAuthCallback(
|
||||
// TODO(kradalby): find a better way to use the cache across modules
|
||||
h.registrationCache,
|
||||
nodeKey.String(),
|
||||
user.Name,
|
||||
&expiry,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue