import json
import logging
import typing

import gql
import jwt
from django.conf import settings
from django.contrib.auth.models import Group
from django_http_exceptions import HTTPExceptions
from gql.transport.exceptions import TransportQueryError
from gql.transport.requests import RequestsHTTPTransport
from pirates.auth import PiratesOIDCAuthenticationBackend

# NOTE(review): this configures the process-wide root logger at DEBUG level as
# an import-time side effect (and is a no-op if the root logger already has
# handlers). Consider moving it into Django's LOGGING setting instead.
logging.basicConfig(level=logging.DEBUG)


class RegistryOIDCAuthenticationBackend(PiratesOIDCAuthenticationBackend):
    """OIDC backend that mirrors Chobotnice group memberships into Django groups.

    After the parent backend resolves the user, the user's Chobotnice group
    shortcuts are fetched over GraphQL and synced into Django ``Group`` objects
    named ``<CHOBOTNICE_GROUP_PREFIX><shortcut>``.
    """

    # Prefix prepended to a Chobotnice group shortcut to form a Django group name.
    CHOBOTNICE_GROUP_PREFIX = "chobo_"

    def _chobotnice_shortcut(self, group_name: str) -> str:
        """Return *group_name* with one leading ``CHOBOTNICE_GROUP_PREFIX`` removed.

        Names that do not start with the prefix are returned unchanged. Only a
        leading prefix is stripped; occurrences elsewhere in the name are kept.
        """
        prefix = self.CHOBOTNICE_GROUP_PREFIX
        if group_name.startswith(prefix):
            return group_name[len(prefix):]
        return group_name

    def _assign_new_user_groups(
        self, user, new_user_groups: list, existing_user_groups=None
    ) -> None:
        """Add *user* to the Django group of every Chobotnice shortcut given.

        :param user: user whose group memberships are updated, then saved.
        :param new_user_groups: Chobotnice group shortcuts the user belongs to.
            Missing Django groups are created on the fly.
        :param existing_user_groups: the user's current groups; read from
            ``user.groups`` when omitted. Groups already listed here are not
            added again.
        """
        if existing_user_groups is None:
            existing_user_groups = user.groups.all()

        # dict.fromkeys de-duplicates while keeping order, so a shortcut that
        # appears in several memberships is only processed once.
        for shortcut in dict.fromkeys(new_user_groups):
            # get_or_create avoids the race of a separate exists()/save()
            # pair; Group.name is unique in django.contrib.auth.
            group, _created = Group.objects.get_or_create(
                name=f"{self.CHOBOTNICE_GROUP_PREFIX}{shortcut}"
            )

            if group not in existing_user_groups:
                user.groups.add(group)

        user.save()

    def _remove_old_user_groups(
        self, user, new_user_groups: list, existing_user_groups=None
    ) -> None:
        """Remove *user* from groups that no longer match a Chobotnice shortcut.

        :param user: user whose group memberships are updated.
        :param new_user_groups: Chobotnice group shortcuts the user belongs to.
        :param existing_user_groups: the user's current groups; read from
            ``user.groups`` when omitted.
        """
        if existing_user_groups is None:
            existing_user_groups = user.groups.all()

        wanted_shortcuts = set(new_user_groups)

        for group in existing_user_groups:
            # NOTE(review): groups without the prefix are removed too, unless
            # their full name happens to equal a Chobotnice shortcut. Confirm
            # whether locally managed groups are meant to be kept.
            if self._chobotnice_shortcut(group.name) not in wanted_shortcuts:
                user.groups.remove(group)

    def get_chobotnice_groups(self, access_token):
        """Return the Chobotnice group shortcuts of the token's subject.

        :param access_token: decoded access-token claims; only ``sub`` is read
            and matched against the person's ``keycloakId``.
        :returns: list of group shortcuts, one per group membership.
        :raises HTTPExceptions.BAD_REQUEST: if the GraphQL query fails with a
            ``TransportQueryError``.
        """
        transport = RequestsHTTPTransport(url=settings.CHOBOTNICE_API_URL)
        client = gql.Client(
            transport=transport,
            fetch_schema_from_transport=True,
        )

        # The claim comes from an unverified token decode. Emitting it via
        # json.dumps produces a properly escaped, quoted string literal (JSON
        # escapes are valid GraphQL string escapes), so quotes or backslashes
        # in the value cannot alter the query structure.
        keycloak_id_literal = json.dumps(str(access_token["sub"]))

        query = gql.gql(
            f"""
                {{
                    allPeople(
                        filters: {{keycloakId: {{exact: {keycloak_id_literal}}}}}
                    ) {{
                        edges {{
                            node {{
                                groupMemberships {{
                                    group {{
                                        shortcut
                                    }}
                                }}
                            }}
                        }}
                    }}
                }}
            """
        )

        try:
            result = client.execute(query)
        except TransportQueryError as err:
            # The API answered with GraphQL errors for this lookup.
            raise HTTPExceptions.BAD_REQUEST from err

        return [
            membership["group"]["shortcut"]
            for person in result["allPeople"]["edges"]
            for membership in person["node"]["groupMemberships"]
        ]

    def get_or_create_user(self, access_token, id_token, payload):
        """Resolve the user via the parent backend, then sync Chobotnice groups.

        Sets ``preferred_username`` from the access token, removes groups that
        no longer correspond to a Chobotnice membership, and adds missing ones.
        It then calls ``update_group_based_admin()`` and saves with
        ``saved_by_auth=True``.

        :returns: the user, or ``None`` when the parent backend returns ``None``.
        """
        user = super().get_or_create_user(access_token, id_token, payload)

        if user is None:
            return None

        # Only the claims are read here; the signature is not verified.
        decoded_access_token = jwt.decode(
            access_token, options={"verify_signature": False}
        )

        user.preferred_username = decoded_access_token["preferred_username"]
        # Shared by both helpers. It is first evaluated (and cached) in
        # _remove_old_user_groups, i.e. before any removal happens.
        existing_user_groups = user.groups.all()
        new_user_groups = self.get_chobotnice_groups(decoded_access_token)

        self._remove_old_user_groups(
            user,
            new_user_groups=new_user_groups,
            existing_user_groups=existing_user_groups,
        )
        self._assign_new_user_groups(
            user,
            new_user_groups=new_user_groups,
            existing_user_groups=existing_user_groups,
        )

        user.update_group_based_admin()
        user.save(saved_by_auth=True)

        return user