diff --git a/oidc/auth.py b/oidc/auth.py index 1c9b54f233e2ed60772fdae164aeef6fb4fc6df8..1ec381397eb91635cc8e626e4683c4458da169b9 100644 --- a/oidc/auth.py +++ b/oidc/auth.py @@ -1,3 +1,4 @@ +import jwt import logging from django.conf import settings @@ -7,15 +8,13 @@ from pirates.auth import PiratesOIDCAuthenticationBackend logging.basicConfig(level=logging.DEBUG) -class NastenkaOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend): - def _assign_new_user_groups(self, user, claims, user_groups=None) -> None: +class RegistryOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend): + def _assign_new_user_groups(self, user, access_token, user_groups=None) -> None: if user_groups is None: user_groups = user.groups.all() - for role in claims["resource_access"][settings.OIDC_RP_RESOURCE_ACCESS_CLIENT][ - "roles" - ]: - group_name = f"sso_{role}" + for group in access_token["groups"]: + group_name = f"sso_{group}" group = Group.objects.filter(name=group_name) @@ -30,28 +29,25 @@ class NastenkaOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend): user.save() - def create_user(self, claims): - user = super().create_user(claims) + def _remove_old_user_groups(self, user, access_token, user_groups=None) -> None: + if user_groups is None: + user_groups = user.groups.all() - if "resource_access" not in claims: - return user + for group in user_groups: + if group.name.replace("sso_", "") not in access_token["groups"]: + user.groups.remove(group) - self._assign_new_user_groups(user, claims) + def get_or_create_user(self, access_token, id_token, payload): + user = super().get_or_create_user(access_token, id_token, payload) - return user + if user is None: + return - def update_user(self, user, claims): - if "resource_access" not in claims: - return user + decoded_access_token = jwt.decode(access_token, options={"verify_signature": False}) user_groups = user.groups.all() - for group in user_groups: - if group.name.replace("sso_", "") not in ( - claims["resource_access"][settings.OIDC_RP_CLIENT_ID]["roles"] - ): - user.groups.remove(group) - - self._assign_new_user_groups(user, claims, user_groups) + self._remove_old_user_groups(user, decoded_access_token, user_groups=user_groups) + self._assign_new_user_groups(user, decoded_access_token, user_groups=user_groups) return user diff --git a/registry/settings/base.py b/registry/settings/base.py index bb94bc3a129730a61637be35b7affc1d500febde..422a33ef949738b022000275c6753d11543ef325 100644 --- a/registry/settings/base.py +++ b/registry/settings/base.py @@ -129,7 +129,7 @@ AUTH_PASSWORD_VALIDATORS = [ AUTH_USER_MODEL = "users.User" AUTHENTICATION_BACKENDS = ( - "oidc.auth.NastenkaOIDCAuthenticationBackend", + "oidc.auth.RegistryOIDCAuthenticationBackend", "django.contrib.auth.backends.ModelBackend", "guardian.backends.ObjectPermissionBackend", ) @@ -141,7 +141,7 @@ LOGOUT_REDIRECT_URL = "/" OIDC_RP_CLIENT_ID = env.str("OIDC_RP_CLIENT_ID") OIDC_RP_CLIENT_SECRET = env.str("OIDC_RP_CLIENT_SECRET") OIDC_RP_REALM_URL = env.str("OIDC_RP_REALM_URL") -OIDC_RP_SCOPES = "openid email roles" +OIDC_RP_SCOPES = "openid profile groups" OIDC_RP_SIGN_ALGO = "RS256" OIDC_RP_RESOURCE_ACCESS_CLIENT = env.str( "OIDC_RESOURCE_ACCESS_CLIENT", OIDC_RP_CLIENT_ID