#!/usr/bin/env nix-shell
#!nix-shell --pure --keep LLDAP_TOKEN -i python3 -p "python3.withPackages (ps: with ps; [ requests gql aiohttp])"

from typing import Any
import subprocess
import secrets
import json
import os

import gql
from gql.dsl import DSLQuery, DSLSchema, dsl_gql, DSLMutation
from gql.transport.aiohttp import AIOHTTPTransport
from pprint import pprint


from .users import LLDAPUsers
from .groups import LLDAPGroups
from .attributes import LLDAPAttributes

import logging

# Install a default stderr handler on the root logger (a no-op if the root
# logger already has handlers), then let both the "lldapbootstrap" logger and
# this module's own logger emit records down to DEBUG level.
logging.basicConfig()
logger = logging.getLogger(__name__)
for _verboseLogger in (logging.getLogger("lldapbootstrap"), logger):
    _verboseLogger.setLevel(logging.DEBUG)
del _verboseLogger


class LLDAP:
    """Declaratively sync an LLDAP server to a JSON description.

    `run()` loads a JSON document with the keys ``user_attributes``,
    ``group_attributes``, ``groups`` and ``users`` and makes the server match
    it: missing entries are created, drifted entries are updated (groups,
    users) or deleted and recreated (attributes), and entries absent from the
    document are removed. Users are soft-deleted by default (see
    `_user_disable`).
    """

    def __init__(self, server_url: str, auth_token: str):
        """Connect to the LLDAP server and set up the per-resource helpers.

        Args:
            server_url: Base URL of the LLDAP instance. It should not end in a
                slash, because ``/api/graphql`` is appended to it directly.
            auth_token: Bearer token sent with every GraphQL request and
                handed to ``lldap_set_password``.

        Note: construction performs a network round-trip (see
        `_init_gql_client`).
        """
        self._server_url: str = server_url
        self._server_auth_token: str = auth_token

        self._client: gql.Client = self._init_gql_client()

        self._users = LLDAPUsers(self._client)
        self._groups = LLDAPGroups(self._client)
        self._attrsUser = LLDAPAttributes(self._client, using_user_attributes=True)
        self._attrsGroup = LLDAPAttributes(self._client, using_group_attributes=True)

    def _init_gql_client(self) -> gql.Client:
        """Build a gql client for ``<server_url>/api/graphql``.

        The client is created with ``fetch_schema_from_transport=True`` and a
        throwaway query is executed immediately so the schema is fetched here
        rather than lazily. Presumably the helper classes rely on
        ``client.schema`` being populated; TODO confirm.
        """
        transport = AIOHTTPTransport(
            url=f"{self._server_url}/api/graphql",
            headers={"Authorization": f"Bearer {self._server_auth_token}"},
        )

        client = gql.Client(transport=transport, fetch_schema_from_transport=True)

        # Force the schema fetch; the query result itself is not needed.
        client.execute(
            gql.gql(
                """
            query {
            users {
                displayName
            }
            }
        """
            )
        )

        return client

    def _run_ensure_attrs_user(self, attrsClass, neededAttrsUser: list[dict[str, Any]]):
        """Make the attribute schema managed by `attrsClass` match `neededAttrsUser`.

        Used for both user and group attributes (see `run`).

        Args:
            attrsClass: An ``LLDAPAttributes`` helper (user or group flavour).
            neededAttrsUser: Entries with the keys ``name``, ``attributeType``,
                ``isEditable``, ``isList`` and ``isVisible``.

        An existing attribute whose flags differ from the wanted ones is
        deleted and recreated. This presumably discards any values already
        stored in it; TODO confirm against LLDAP. Remote attributes that are
        not wanted are deleted, except hardcoded ones.
        """
        neededAttrNames = {v["name"] for v in neededAttrsUser}
        remoteAttrs = attrsClass.list_all()

        # add needed attributes
        for neededAttr in neededAttrsUser:
            neededAttrName = neededAttr["name"]

            if neededAttrName in remoteAttrs:
                cattr = remoteAttrs[neededAttrName]
                if (
                    neededAttr["attributeType"] == cattr.attributeType
                    and neededAttr["isEditable"] == cattr.isEditable
                    and neededAttr["isList"] == cattr.isList
                    and neededAttr["isVisible"] == cattr.isVisible
                ):
                    # already in sync
                    continue

                logger.debug(
                    f"attribute '{neededAttrName}' out of sync, deleting and adding again"
                )
                attrsClass.delete(neededAttrName)

            attrsClass.create(
                neededAttrName,
                attributeType=neededAttr["attributeType"],
                isEditable=neededAttr["isEditable"],
                isList=neededAttr["isList"],
                isVisible=neededAttr["isVisible"],
            )

        # remove unneeded attributes (hardcoded ones are never touched)
        for remoteAttrName, remoteAttr in remoteAttrs.items():
            if remoteAttr.isHardcoded:
                continue

            if remoteAttrName not in neededAttrNames:
                attrsClass.delete(remoteAttrName)

    def _run_ensure_groups(self, neededGroups: list[dict[str, Any]]):
        """Make the server's groups match `neededGroups`.

        Args:
            neededGroups: Entries that must contain ``display_name``. All
                other keys are forwarded as the update payload to
                ``LLDAPGroups.update``. The caller's dicts are not modified.

        Groups missing on the server are created. Groups whose name starts
        with ``lldap_`` are never deleted; any other group not listed is.
        """
        neededGroupNames = {v["display_name"] for v in neededGroups}
        remoteGroups = self._groups.list_all()

        for neededGroup in neededGroups:
            neededGroupName = neededGroup["display_name"]

            if neededGroupName not in remoteGroups:
                self._groups.create(neededGroupName)

                # refresh groups so the new group's id is available
                remoteGroups = self._groups.list_all()

            remoteGroup = remoteGroups[neededGroupName]

            # The display name is the lookup key and is never updated. Build the
            # payload from a copy instead of `del`-ing from the caller's dict,
            # which broke any second pass over the same config.
            groupUpdate = {k: v for k, v in neededGroup.items() if k != "display_name"}

            self._groups.update(remoteGroup.groupId, groupUpdate)

        # delete unused groups
        for remoteGroupName, remoteGroup in remoteGroups.items():
            # skip all lldap_ groups
            if remoteGroupName.startswith("lldap_"):
                continue

            if remoteGroupName not in neededGroupNames:
                self._groups.delete(remoteGroup.groupId)

    def _run_ensure_users(
        self,
        neededUsers: list[dict[str, Any]],
        softDelete: bool = True,
    ):
        """Make the server's users match `neededUsers`.

        Args:
            neededUsers: Entries that must contain ``user_id``. The optional
                ``groups`` key (list of group names) sets group membership, and
                the optional ``password`` key sets the password (see
                `_resolve_password`). Both are stripped before the remaining
                keys are forwarded to ``LLDAPUsers.update``. ``mail``, if
                present, is also used as the e-mail at creation time. The
                caller's dicts are not modified.
            softDelete: If true, unlisted users are disabled via
                `_user_disable`; otherwise they are deleted.
        """
        neededUserIds = {v["user_id"] for v in neededUsers}
        remoteUsers = self._users.list_all()

        for neededUserSpec in neededUsers:
            # Work on a shallow copy so the caller's config stays intact. After
            # popping the control keys, what remains is exactly the payload for
            # `self._users.update`.
            neededUser = dict(neededUserSpec)
            neededUserId = neededUser.pop("user_id")
            neededUserGroups = neededUser.pop("groups", [])
            neededUserPassword: str | None = neededUser.pop("password", None)

            # create user if needed
            if neededUserId not in remoteUsers:
                self._users.create(
                    neededUserId,
                    neededUser.get("mail", "no-email-specified"),
                )

                # refresh users so the new user is present below
                remoteUsers = self._users.list_all()

            # update user
            self._users.update(neededUserId, neededUser)

            remoteUser = remoteUsers[neededUserId]

            if neededUserGroups:
                logger.info(
                    f"using attribute 'groups' for userId '{neededUserId}', for setting groups, NOT SETTING AS ATTRIBUTE!!"
                )

            # add to missing groups
            for groupName in neededUserGroups:
                if groupName not in remoteUser.groups:
                    self._users.add_group(
                        neededUserId,
                        self._groups.name_to_id(groupName),
                    )

            # remove from groups that are no longer wanted
            for groupName in remoteUser.groups:
                if groupName not in neededUserGroups:
                    self._users.remove_group(
                        neededUserId,
                        self._groups.name_to_id(groupName),
                    )

            if neededUserPassword:
                logger.info(
                    f"using attribute 'password' for userId '{neededUserId}', for setting password, NOT SETTING AS ATTRIBUTE!!"
                )
                self._user_set_password(
                    neededUserId,
                    self._resolve_password(neededUserId, neededUserPassword),
                )

        # delete (or disable) users that are not listed
        for remoteUserName, remoteUser in remoteUsers.items():
            if remoteUserName not in neededUserIds:
                if softDelete:
                    self._user_disable(remoteUserName)
                else:
                    self._users.delete(remoteUser.userId)

    def _resolve_password(self, userId: str, passwordSpec: str) -> str:
        """Turn a ``password`` value from the config into the actual password.

        Supported forms:
            ``file:<path>``: contents of <path>, surrounding whitespace stripped.
            ``env:<NAME>``: value of environment variable NAME, stripped.
            anything else: the value itself, stripped.

        Raises:
            Exception: if the ``env:`` variable is unset or empty.
            OSError: if the ``file:`` path cannot be opened.
        """
        if passwordSpec.startswith("file:"):
            passwordFile = passwordSpec[len("file:") :]
            logger.debug(
                f"reading password from file '{passwordFile}' for user '{userId}'"
            )
            # `with` closes the handle; the previous bare open() leaked it.
            with open(passwordFile, "r") as f:
                return f.read().strip()

        if passwordSpec.startswith("env:"):
            cleanedPasswordEnv = passwordSpec.strip()[len("env:") :]
            logger.debug(
                f"reading password from envvar '{cleanedPasswordEnv}' for user '{userId}'"
            )

            password = os.getenv(cleanedPasswordEnv)
            if not password:
                raise Exception(
                    f"could not find env '{cleanedPasswordEnv}' for getting password"
                )
            return password.strip()

        logger.debug(
            f"using the raw value of password, as the password for user '{userId}'"
        )
        return passwordSpec.strip()

    def _user_disable(self, userId: str, disabled_group_name: str = "disabled"):
        """Soft-delete a user.

        The user is removed from every group except `disabled_group_name`.
        Unless the user was already in that group, the group is created if
        needed, the user is added to it, and the password is replaced with a
        long random token. Does nothing if the user cannot be fetched.
        """
        user = self._users.get(userId)
        if not user:
            return

        # remove all groups except the marker group
        for groupName, groupId in user.groups.items():
            if groupName == disabled_group_name:
                continue

            self._users.remove_group(userId, groupId)

        # already disabled earlier: nothing more to do
        if disabled_group_name in user.groups:
            return

        # ensure group exists
        groups = self._groups.list_all()
        if disabled_group_name not in groups:
            self._groups.create(disabled_group_name)

        self._users.add_group(userId, self._groups.name_to_id(disabled_group_name))

        # Set an unguessable password so the account can no longer log in.
        self._user_set_password(userId, secrets.token_urlsafe(128))

    def _user_set_password(self, userId: str, password: str):
        """Set `userId`'s password by invoking the ``lldap_set_password`` CLI.

        Raises:
            subprocess.CalledProcessError: if the tool exits non-zero. The
                ``cmd`` attribute is redacted (see below).
        """
        # NOTE(review): the token and password are passed on argv, which makes
        # them visible in the process list while the tool runs. Check whether
        # lldap_set_password can read them from stdin or the environment.
        cmd = [
            "lldap_set_password",
            f"--base-url={self._server_url}",
            f"--token={self._server_auth_token}",
            f"--username={userId}",
            f"--password={password}",
        ]
        try:
            subprocess.check_output(cmd)
        except subprocess.CalledProcessError as e:
            # The default error message embeds the full argv, i.e. the admin
            # token and the new password. Re-raise the same exception type with
            # a redacted command, and suppress the original context so neither
            # secret lands in tracebacks or logs.
            raise subprocess.CalledProcessError(
                e.returncode,
                [
                    "lldap_set_password",
                    f"--base-url={self._server_url}",
                    f"--username={userId}",
                    "<token and password redacted>",
                ],
                output=e.output,
                stderr=e.stderr,
            ) from None

    def run(self, config_path: str = "test2.json"):
        """Load the JSON config at `config_path` and sync the server to it.

        Order: user attributes, group attributes, groups, then users. Users
        come last because group membership is resolved by group name.
        """
        with open(config_path, "r") as f:
            data = json.load(f)

        self._run_ensure_attrs_user(self._attrsUser, data["user_attributes"])
        self._run_ensure_attrs_user(self._attrsGroup, data["group_attributes"])
        self._run_ensure_groups(data["groups"])
        self._run_ensure_users(data["users"])


if __name__ == "__main__":
    # The admin bearer token comes from the environment (the nix-shell
    # shebang keeps LLDAP_TOKEN through --pure); refuse to run without it.
    token = os.getenv("LLDAP_TOKEN")
    if not token:
        raise Exception("No LLDAP_TOKEN provided. please set")

    LLDAP("https://ldap.fricloud.dk", token).run()