136 lines
3.9 KiB
Python
136 lines
3.9 KiB
Python
from typing import Any
|
|
import re
|
|
|
|
import gql
|
|
from gql.dsl import DSLQuery, DSLSchema, dsl_gql, DSLMutation
|
|
from gql.transport.aiohttp import AIOHTTPTransport
|
|
|
|
from .utils import to_camelcase, to_snakecase
|
|
|
|
import logging
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Group:
    """A group built from a list of raw ``{"name": ..., "value": ...}`` attributes.

    Attribute names are converted with ``to_camelcase`` and exposed as regular
    attributes (e.g. ``group.displayName``). Unknown public names resolve to
    ``""``. Values are stored exactly as given. ``__init__`` indexes them with
    ``[0]``, so they are expected to be sequences; presumably these are the
    list-valued attributes of a GraphQL response (TODO confirm).
    """

    def __init__(
        self,
        raw_attrs: list[dict[str, Any]],
    ):
        """Index ``raw_attrs`` by name and extract the group id and display name.

        Args:
            raw_attrs: Attribute entries, each with ``"name"`` and ``"value"``
                keys. Any other keys are ignored.

        Raises:
            IndexError: If ``groupId`` or ``displayName`` is missing (it
                resolves to ``""``) or has an empty value.
            ValueError: If the first ``groupId`` value is not an integer.
        """
        self._attributes: dict[str, Any] = {
            item["name"]: item["value"] for item in raw_attrs
        }

        # Build the camelCase view once rather than once per field.
        camel = self._attributes_camelcase()
        self.groupId: int = int(camel.get("groupId", "")[0])
        self.name: str = camel.get("displayName", "")[0]

    def _attributes_camelcase(self) -> dict[str, Any]:
        """Return a new dict of the attributes keyed by their camelCase names."""
        return {to_camelcase(k): v for k, v in self._attributes.items()}

    def __getattr__(self, key: str) -> Any:
        """Resolve names that normal attribute lookup did not find.

        Private and dunder names raise ``AttributeError`` instead of falling
        through to the attribute map. Without this:

        - an instance created without ``__init__`` (as ``copy``/``pickle`` do)
          recurses forever looking up ``_attributes``;
        - protocol probes such as ``getattr(obj, "__deepcopy__", None)``
          receive ``""`` and then try to call it.

        Returns:
            The attribute's value, or ``""`` if the group has no such attribute.

        Raises:
            AttributeError: If ``key`` starts with an underscore.
        """
        if key.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        return self._attributes_camelcase().get(key, "")

    def __repr__(self) -> str:
        return f"<Group groupId={self.groupId} name={self.name} />"
|
|
|
|
|
|
class LLDAPGroups:
    """Group queries and mutations executed through a ``gql.Client``.

    Each operation builds its document with ``gql.dsl`` from the client's
    ``schema`` and runs it with ``client.execute``.
    """

    def __init__(self, client: gql.Client):
        """Store the client used for every request.

        Args:
            client: ``gql`` client whose ``schema`` is used to build DSL queries.
        """
        self._client = client

    def _dsl(self) -> DSLSchema:
        """Return a ``DSLSchema`` wrapping the client's current schema."""
        return DSLSchema(self._client.schema)

    def list_all(self) -> dict[str, Group]:
        """Fetch every group together with its attributes.

        Returns:
            A dict mapping each group's display name to its ``Group``. If two
            groups share a display name, the later one in the response wins.
        """
        ds = self._dsl()
        query = dsl_gql(
            DSLQuery(
                ds.Query.groups().select(
                    ds.Group.attributes.select(
                        ds.AttributeValue.name,
                        ds.AttributeValue.value,
                        ds.AttributeValue.schema.select(
                            ds.AttributeSchema.isHardcoded,
                        ),
                    ),
                ),
            ),
        )
        result = self._client.execute(query)

        groups: dict[str, Group] = {}
        for raw_group in result["groups"]:
            group = Group(raw_attrs=raw_group.get("attributes", []))
            groups[group.name] = group
        return groups

    def create(self, groupName: str):
        """Create a group named ``groupName``.

        The mutation selects the new group's ``displayName``, but the
        response is discarded.
        """
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("creating group with name '%s'", groupName)

        ds = self._dsl()
        query = dsl_gql(
            DSLMutation(
                ds.Mutation.createGroup.args(
                    name=groupName,
                ).select(ds.Group.displayName)
            ),
        )
        self._client.execute(query)

    def get_by_name(self, groupName: str) -> Group | None:
        """Return the group whose display name is ``groupName``, or ``None``.

        Fetches the full group list on every call.
        """
        return self.list_all().get(groupName)

    def get_by_id(self, groupId: int) -> Group | None:
        """Return the group whose ``groupId`` equals ``groupId``, or ``None``.

        Fetches the full group list on every call.
        """
        return next(
            (
                group
                for group in self.list_all().values()
                if group.groupId == groupId
            ),
            None,
        )

    def name_to_id(self, groupName: str) -> int:
        """Resolve a group display name to its numeric id.

        Raises:
            LookupError: If no group has that display name. ``LookupError``
                subclasses ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        group = self.get_by_name(groupName)
        if group is None:
            raise LookupError(f"no group with the name {groupName}")
        return group.groupId

    def update(self, groupId: int, attrs: dict[str, str | list[str]]):
        """Send ``attrs`` as ``insertAttributes`` for the group ``groupId``.

        Args:
            groupId: Numeric id of the group to update.
            attrs: Attribute values keyed by name. Names are converted with
                ``to_snakecase``, and a single string value is wrapped in a
                one-element list.
        """
        insertAttributes: list[dict[str, str | list[str]]] = [
            {"name": to_snakecase(k), "value": [v] if isinstance(v, str) else v}
            for k, v in attrs.items()
        ]

        ds = self._dsl()
        query = dsl_gql(
            DSLMutation(
                ds.Mutation.updateGroup.args(
                    group={
                        "id": groupId,
                        "insertAttributes": insertAttributes,
                    },
                ).select(ds.Success.ok),
            ),
        )
        self._client.execute(query)

    def delete(self, groupId: int):
        """Delete the group with id ``groupId``. The response is discarded."""
        logger.debug("deleting group with id '%s'", groupId)

        ds = self._dsl()
        query = dsl_gql(
            DSLMutation(
                ds.Mutation.deleteGroup.args(
                    groupId=groupId,
                ).select(ds.Success.ok),
            ),
        )
        self._client.execute(query)

    def test(self):
        """Run ``list_all`` and discard the result.

        Presumably a connectivity smoke test; confirm with callers.
        """
        self.list_all()
|