From ca07ad1e1e0a0cc3901bf76e08b7f0493c0c37bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Valenta?= <git@imaniti.org>
Date: Sun, 9 Jul 2023 12:15:11 +0900
Subject: [PATCH] sync Groups from Chobotnice

---
 oidc/auth.py              | 84 +++++++++++++++++++++++++++++----------
 ucebnice/settings/base.py |  7 ++++
 2 files changed, 70 insertions(+), 21 deletions(-)

diff --git a/oidc/auth.py b/oidc/auth.py
index 97ef6a2..b6785c1 100644
--- a/oidc/auth.py
+++ b/oidc/auth.py
@@ -11,16 +11,13 @@ logging.basicConfig(level=logging.DEBUG)
 
 class UcebniceOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
     def _assign_new_user_groups(
-        self, user, access_token: dict, user_groups: typing.Union[None, list] = None
+        self, user, new_user_groups: list, existing_user_groups=None
     ) -> None:
-        if user_groups is None:
-            user_groups = user.groups.all()
+        if existing_user_groups is None:
+            existing_user_groups = user.groups.all()
 
-        for group in access_token["groups"]:
-            if group.startswith("_"):  # Ignore internal Keycloak groups
-                continue
-
-            group_name = f"sso_{group}"
+        for group in new_user_groups:
+            group_name = f"chobo_{group}"
 
             group = Group.objects.filter(name=group_name)
 
@@ -30,19 +27,62 @@ class UcebniceOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
             else:
                 group = group[0]
 
-            if group not in user_groups:
+            if group not in existing_user_groups:
                 user.groups.add(group)
 
+        user.save()
+
     def _remove_old_user_groups(
-        self, user, access_token: dict, user_groups: typing.Union[None, list] = None
+        self, user, new_user_groups: list, existing_user_groups=None
     ) -> None:
-        if user_groups is None:
-            user_groups = user.groups.all()
+        if existing_user_groups is None:
+            existing_user_groups = user.groups.all()
 
-        for group in user_groups:
-            if group.name.replace("sso_", "") not in access_token["groups"]:
+        for group in existing_user_groups:
+            if group.name.replace("chobo_", "") not in new_user_groups:
                 user.groups.remove(group)
 
+    def get_chobotnice_groups(self, access_token):
+        transport = RequestsHTTPTransport(url=settings.CHOBOTNICE_API_URL)
+        client = gql.Client(
+            transport=transport,
+            fetch_schema_from_transport=True,
+        )
+
+        query = gql.gql(
+            f"""
+                {{
+                    allPeople(
+                        filters: {{keycloakId: {{exact: "{access_token['sub']}"}}}}
+                    ) {{
+                        edges {{
+                            node {{
+                                groupMemberships {{
+                                    group {{
+                                        shortcut
+                                    }}
+                                }}
+                            }}
+                        }}
+                    }}
+                }}
+            """
+        )
+
+        try:
+            result = client.execute(query)
+        except TransportQueryError:
+            # rv_gid was not found
+            raise HTTPExceptions.BAD_REQUEST
+
+        groups = []
+
+        for person in result["allPeople"]["edges"]:
+            for group_membership in person["node"]["groupMemberships"]:
+                groups.append(group_membership["group"]["shortcut"])
+
+        return groups
+
     def get_or_create_user(self, access_token, id_token, payload):
         user = super().get_or_create_user(access_token, id_token, payload)
 
@@ -53,17 +93,19 @@ class UcebniceOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
             access_token, options={"verify_signature": False}
         )
 
-        user.sso_username = decoded_access_token["preferred_username"]
-        user.email = decoded_access_token["email"]
-        user_groups = user.groups.all()
+        user.preferred_username = decoded_access_token["preferred_username"]
+        existing_user_groups = user.groups.all()
+        new_user_groups = self.get_chobotnice_groups(decoded_access_token)
 
         self._remove_old_user_groups(
-            user, decoded_access_token, user_groups=user_groups
+            user,
+            new_user_groups=new_user_groups,
+            existing_user_groups=existing_user_groups,
         )
         self._assign_new_user_groups(
-            user, decoded_access_token, user_groups=user_groups
+            user,
+            new_user_groups=new_user_groups,
+            existing_user_groups=existing_user_groups,
         )
 
-        user.save()
-
         return user
diff --git a/ucebnice/settings/base.py b/ucebnice/settings/base.py
index 1a6d089..96f9c48 100644
--- a/ucebnice/settings/base.py
+++ b/ucebnice/settings/base.py
@@ -200,6 +200,13 @@ ADMIN_INDEX_SHOW_REMAINING_APPS = True
 ADMIN_ORDERING = {}
 
 
+# Chobotnice
+
+CHOBOTNICE_API_URL = env.str(
+    "CHOBOTNICE_API_URL", "https://chobotnice.pirati.cz/graphql/"
+)
+
+
 # DBsettings
 
 DBSETTINGS_VALUE_LENGTH = 65536
-- 
GitLab