import logging

import jwt
from django.conf import settings
from django.contrib.auth.models import Group

from pirates.auth import PiratesOIDCAuthenticationBackend

# NOTE(review): configuring the root logger at import time is a global side
# effect (basicConfig only takes effect if the root logger has no handlers
# yet); consider moving this into Django's LOGGING setting instead.
logging.basicConfig(level=logging.DEBUG)


class RegistryOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
    """OIDC backend that mirrors the access token's ``groups`` claim onto
    Django groups.

    Each entry ``g`` of the claim maps to a Django ``Group`` named
    ``f"{sso_group_prefix}{g}"``. On every login, prefixed groups that are no
    longer present in the token are removed from the user, and missing ones
    are created (when needed) and added. Groups whose name does not carry the
    prefix are not touched by the synchronisation.
    """

    # Name prefix marking a Django group as managed by SSO synchronisation.
    sso_group_prefix = "sso_"

    @staticmethod
    def _get_token_groups(access_token):
        """Return the ``groups`` claim of decoded access-token claims as a list.

        A missing or empty claim yields ``[]``; a single string value is
        wrapped in a list so it is not iterated character by character.
        """
        groups = access_token.get("groups") or []
        if isinstance(groups, str):
            return [groups]
        return list(groups)

    def _assign_new_user_groups(self, user, access_token, user_groups=None) -> None:
        """Add ``user`` to the SSO-managed group of every group in the token.

        Args:
            user: The Django user being synchronised.
            access_token: The decoded access-token claims (a dict).
            user_groups: The user's current groups; queried from
                ``user.groups`` when omitted.
        """
        if user_groups is None:
            user_groups = user.groups.all()
        current_names = {group.name for group in user_groups}

        groups_to_add = []
        for token_group in self._get_token_groups(access_token):
            group_name = f"{self.sso_group_prefix}{token_group}"
            if group_name in current_names:
                continue
            # Also guards against the same group appearing twice in the claim.
            current_names.add(group_name)
            group, _created = Group.objects.get_or_create(name=group_name)
            groups_to_add.append(group)

        if groups_to_add:
            user.groups.add(*groups_to_add)
            # The m2m add is persisted by itself; the save is kept (once,
            # rather than per group) so user save hooks still run.
            user.save()

    def _remove_old_user_groups(self, user, access_token, user_groups=None) -> None:
        """Remove ``user`` from SSO-managed groups absent from the token.

        Only groups whose name starts with ``sso_group_prefix`` are
        considered; groups assigned by other means are kept.

        Args:
            user: The Django user being synchronised.
            access_token: The decoded access-token claims (a dict).
            user_groups: The user's current groups; queried from
                ``user.groups`` when omitted.
        """
        if user_groups is None:
            user_groups = user.groups.all()
        prefix = self.sso_group_prefix
        expected_names = {
            f"{prefix}{token_group}"
            for token_group in self._get_token_groups(access_token)
        }
        stale_groups = [
            group
            for group in user_groups
            if group.name.startswith(prefix) and group.name not in expected_names
        ]
        if stale_groups:
            user.groups.remove(*stale_groups)

    def get_or_create_user(self, access_token, id_token, payload):
        """Resolve the user via the parent backend, then sync its SSO groups.

        Returns:
            The user returned by the parent backend, or ``None`` when the
            parent returns ``None`` (no group synchronisation happens then).
        """
        user = super().get_or_create_user(access_token, id_token, payload)
        if user is None:
            return None

        # NOTE(review): the signature is not verified here; this presumably
        # relies on the token having been obtained directly from the
        # provider's token endpoint by the parent backend — confirm.
        decoded_access_token = jwt.decode(
            access_token, options={"verify_signature": False}
        )
        # Materialise the membership once so both helpers work from the same
        # pre-sync snapshot with a single query.
        user_groups = list(user.groups.all())
        self._remove_old_user_groups(
            user, decoded_access_token, user_groups=user_groups
        )
        self._assign_new_user_groups(
            user, decoded_access_token, user_groups=user_groups
        )
        return user