import typing
import logging

import jwt
from django.contrib.auth.models import Group
from django.conf import settings

from pirates.auth import PiratesOIDCAuthenticationBackend

logging.basicConfig(level=logging.DEBUG)


class RegistryOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
    """OIDC backend that mirrors the access token's ``groups`` claim onto Django groups.

    Every token group ``<name>`` (except names starting with ``_``) is mapped to
    a Django :class:`~django.contrib.auth.models.Group` named ``sso_<name>``,
    created on demand. On each login the user's memberships are synchronised
    with the token: groups not backed by the token are removed, missing ones
    are added.
    """

    # Prefix of Django group names that mirror SSO token groups.
    sso_group_prefix = "sso_"

    def _strip_sso_prefix(self, group_name: str) -> str:
        """Return *group_name* with a single leading SSO prefix removed, if present."""
        prefix = self.sso_group_prefix
        if group_name.startswith(prefix):
            return group_name[len(prefix):]
        return group_name

    def _assign_new_user_groups(
        self,
        user,
        access_token: dict,
        user_groups: typing.Optional[typing.Iterable] = None,
    ) -> None:
        """Add *user* to the Django group mirroring each SSO group in the token.

        :param user: the user being authenticated.
        :param access_token: decoded access-token claims. ``access_token["groups"]``
            must be an iterable of group-name strings; a missing claim raises
            ``KeyError``.
        :param user_groups: the user's current groups. Fetched from the database
            when omitted.
        """
        if user_groups is None:
            user_groups = user.groups.all()
        # Materialise the names once so each membership test is a set lookup
        # rather than a scan / database round-trip.
        current_names = {group.name for group in user_groups}

        for token_group in access_token["groups"]:
            # Token groups starting with "_" are deliberately not mirrored.
            if token_group.startswith("_"):
                continue

            group_name = f"{self.sso_group_prefix}{token_group}"
            if group_name in current_names:
                continue

            # get_or_create avoids the check-then-insert race on the unique
            # Group.name that a separate exists()/save() pair would have.
            group, _created = Group.objects.get_or_create(name=group_name)
            user.groups.add(group)
            current_names.add(group_name)

    def _remove_old_user_groups(
        self,
        user,
        access_token: dict,
        user_groups: typing.Optional[typing.Iterable] = None,
    ) -> None:
        """Remove *user* from every group that is not backed by the token.

        A group ``sso_<name>`` is kept iff ``<name>`` is in the token's
        ``groups`` claim. A group without the prefix is kept iff its full name
        appears in the claim.

        :param user: the user being authenticated.
        :param access_token: decoded access-token claims. ``access_token["groups"]``
            must be an iterable of group-name strings; a missing claim raises
            ``KeyError``.
        :param user_groups: the user's current groups. Fetched from the database
            when omitted.
        """
        if user_groups is None:
            user_groups = user.groups.all()

        token_groups = set(access_token["groups"])
        # NOTE(review): groups without the SSO prefix (e.g. assigned manually)
        # are removed as well unless the token lists them verbatim -- confirm
        # this is intended.
        for group in user_groups:
            # Strip only a leading prefix; str.replace would also drop "sso_"
            # occurring inside the name and mis-map e.g. "sso_a_sso_b".
            if self._strip_sso_prefix(group.name) not in token_groups:
                user.groups.remove(group)

    def get_or_create_user(self, access_token, id_token, payload):
        """Authenticate via the parent backend, then sync the user's SSO groups.

        :returns: the saved user, or ``None`` when the parent backend yields no user.
        """
        user = super().get_or_create_user(access_token, id_token, payload)
        if user is None:
            return None

        # SECURITY NOTE(review): the token signature is NOT verified here.
        # Presumably this is acceptable because the parent backend obtained
        # the token directly from the provider -- confirm before reusing.
        claims = jwt.decode(access_token, options={"verify_signature": False})

        # Snapshot memberships once so both passes see the same "before" state.
        user_groups = list(user.groups.all())
        self._remove_old_user_groups(user, claims, user_groups=user_groups)
        self._assign_new_user_groups(user, claims, user_groups=user_groups)

        user.update_group_based_admin()
        user.save(saved_by_auth=True)
        return user