Skip to content
Snippets Groups Projects
Commit 9cc8d68c authored by jan.bednarik's avatar jan.bednarik
Browse files

Sort reports

parent 90849893
Branches
No related tags found
No related merge requests found
...@@ -6,10 +6,11 @@ from .paginator import Paginator ...@@ -6,10 +6,11 @@ from .paginator import Paginator
from .sanitizers import extract_text from .sanitizers import extract_text
from .. import search from .. import search
from ..models import OpenIdClient from ..models import OpenIdClient
from ..models import User, Report, UserSort from ..models import User, Report, UserSort, ReportSort
UserSortEnum = graphene.Enum.from_enum(UserSort) UserSortEnum = graphene.Enum.from_enum(UserSort)
ReportSortEnum = graphene.Enum.from_enum(ReportSort)
class AuthorsConnection(relay.Connection): class AuthorsConnection(relay.Connection):
...@@ -50,6 +51,10 @@ class Query: ...@@ -50,6 +51,10 @@ class Query:
description="Fulltext search in Reports. Returns first 10 nodes if pagination is not specified.", description="Fulltext search in Reports. Returns first 10 nodes if pagination is not specified.",
query=graphene.String(description="Text to search for."), query=graphene.String(description="Text to search for."),
highlight=graphene.Boolean(default_value=False, description=highlight_help), highlight=graphene.Boolean(default_value=False, description=highlight_help),
sort=ReportSortEnum(),
reversed=graphene.Boolean(
default_value=False, description="Reverse order of sort."
),
) )
viewer = graphene.Field(types.User, description="Active user viewing API.") viewer = graphene.Field(types.User, description="Active user viewing API.")
login_shortcuts = graphene.List( login_shortcuts = graphene.List(
...@@ -84,10 +89,20 @@ class Query: ...@@ -84,10 +89,20 @@ class Query:
def resolve_search_reports(self, info, **kwargs): def resolve_search_reports(self, info, **kwargs):
paginator = Paginator(**kwargs) paginator = Paginator(**kwargs)
query = kwargs.get("query", "") query = kwargs.get("query", "")
query = extract_text(query) query = extract_text(query)
params = {"highlight": kwargs.get("highlight")}
params = {
"highlight": kwargs.get("highlight"),
"reversed": kwargs.get("reversed"),
}
if "sort" in kwargs:
params["sort"] = ReportSort(kwargs["sort"])
response = search.search_reports(paginator, query=query, **params) response = search.search_reports(paginator, query=query, **params)
total = response.hits.total total = response.hits.total
page_info = paginator.get_page_info(total) page_info = paginator.get_page_info(total)
......
...@@ -15,6 +15,12 @@ class UserSort(Enum): ...@@ -15,6 +15,12 @@ class UserSort(Enum):
TOTAL_REPORTS = "total_reports" TOTAL_REPORTS = "total_reports"
class ReportSort(Enum):
DATE = "date"
PUBLISHED = "published"
RELEVANCE = "relevance"
class CustomUserManager(UserManager): class CustomUserManager(UserManager):
def with_total_reports(self): def with_total_reports(self):
return self.get_queryset().annotate( return self.get_queryset().annotate(
......
from .documents import ReportDoc from .documents import ReportDoc
from .models import ReportSort
HIGHLIGHT_PARAMS = { HIGHLIGHT_PARAMS = {
...@@ -8,7 +9,15 @@ HIGHLIGHT_PARAMS = { ...@@ -8,7 +9,15 @@ HIGHLIGHT_PARAMS = {
} }
def search_reports(paginator, *, query=None, highlight=False, author_id=None): def search_reports(
paginator,
*,
query=None,
highlight=False,
author_id=None,
sort=ReportSort.PUBLISHED,
reversed=False,
):
fields = [ fields = [
"title", "title",
"body", "body",
...@@ -31,7 +40,12 @@ def search_reports(paginator, *, query=None, highlight=False, author_id=None): ...@@ -31,7 +40,12 @@ def search_reports(paginator, *, query=None, highlight=False, author_id=None):
if highlight: if highlight:
s = s.highlight(*fields, **HIGHLIGHT_PARAMS) s = s.highlight(*fields, **HIGHLIGHT_PARAMS)
s = s.sort("-published") if sort == ReportSort.PUBLISHED:
s = s.sort("published" if reversed else "-published")
elif sort == ReportSort.DATE:
s = s.sort("date" if reversed else "-date")
elif sort == ReportSort.RELEVANCE:
s = s.sort({"_score": {"order": "asc" if reversed else "desc"}}, "-published")
s = s[paginator.slice_from : paginator.slice_to] s = s[paginator.slice_from : paginator.slice_to]
return s.execute() return s.execute()
...@@ -56,7 +56,7 @@ reports = [ ...@@ -56,7 +56,7 @@ reports = [
}, },
{ {
"id": 3, "id": 3,
"date": arrow.get(2018, 1, 5).datetime, "date": arrow.get(2018, 1, 2).datetime,
"published": arrow.get(2018, 1, 6).datetime, "published": arrow.get(2018, 1, 6).datetime,
"edited": arrow.get(2018, 1, 6, 7).datetime, "edited": arrow.get(2018, 1, 6, 7).datetime,
"title": "The Return of the King", "title": "The Return of the King",
......
...@@ -232,7 +232,7 @@ snapshots['test_with_reports 1'] = { ...@@ -232,7 +232,7 @@ snapshots['test_with_reports 1'] = {
'cursor': 'MQ==', 'cursor': 'MQ==',
'node': { 'node': {
'body': 'Aragorn is the King. And we have lost the Ring.', 'body': 'Aragorn is the King. And we have lost the Ring.',
'date': '2018-01-05 00:00:00+00:00', 'date': '2018-01-02 00:00:00+00:00',
'edited': '2018-01-06 07:00:00+00:00', 'edited': '2018-01-06 07:00:00+00:00',
'extra': None, 'extra': None,
'id': 'UmVwb3J0OjM=', 'id': 'UmVwb3J0OjM=',
......
...@@ -23,7 +23,7 @@ snapshots['test_all 1'] = { ...@@ -23,7 +23,7 @@ snapshots['test_all 1'] = {
'totalReports': 2 'totalReports': 2
}, },
'body': 'Aragorn is the King. And we have lost the Ring.', 'body': 'Aragorn is the King. And we have lost the Ring.',
'date': '2018-01-05 00:00:00+00:00', 'date': '2018-01-02 00:00:00+00:00',
'edited': '2018-01-06 07:00:00+00:00', 'edited': '2018-01-06 07:00:00+00:00',
'extra': None, 'extra': None,
'hasRevisions': False, 'hasRevisions': False,
...@@ -153,7 +153,7 @@ snapshots['test_highlight 1'] = { ...@@ -153,7 +153,7 @@ snapshots['test_highlight 1'] = {
'totalReports': 2 'totalReports': 2
}, },
'body': 'Aragorn is the King. And we have lost the <mark>Ring</mark>.', 'body': 'Aragorn is the King. And we have lost the <mark>Ring</mark>.',
'date': '2018-01-05 00:00:00+00:00', 'date': '2018-01-02 00:00:00+00:00',
'edited': '2018-01-06 07:00:00+00:00', 'edited': '2018-01-06 07:00:00+00:00',
'extra': None, 'extra': None,
'hasRevisions': False, 'hasRevisions': False,
......
import pytest import pytest
from graphql_relay import from_global_id
from ..dummy import prepare_reports from ..dummy import prepare_reports
...@@ -231,3 +232,37 @@ def test_last_before(call_api, snapshot): ...@@ -231,3 +232,37 @@ def test_last_before(call_api, snapshot):
""" """
response = call_api(query) response = call_api(query)
snapshot.assert_match(response) snapshot.assert_match(response)
@pytest.mark.parametrize(
"params, expected_ids",
[
("sort: PUBLISHED", [3, 2, 1]),
("sort: PUBLISHED, reversed: true", [1, 2, 3]),
("sort: DATE", [2, 3, 1]),
("sort: DATE, reversed: true", [1, 3, 2]),
("sort: RELEVANCE", [3, 2, 1]),
("sort: RELEVANCE, reversed: true", [3, 2, 1]),
('query: "ring", sort: RELEVANCE', [1, 3]),
('query: "ring", sort: RELEVANCE, reversed: true', [3, 1]),
],
)
def test_sort(params, expected_ids, call_api, snapshot):
prepare_reports()
query = f"""
query {{
searchReports ({params}) {{
totalCount
edges {{
cursor
node {{
id
}}
}}
}}
}}
"""
response = call_api(query)
ids = [edge["node"]["id"] for edge in response["data"]["searchReports"]["edges"]]
ids = [int(id) for type, id in map(from_global_id, ids)]
assert ids == expected_ids
import pytest import pytest
from openlobby.core.api.paginator import Paginator, encode_cursor from openlobby.core.api.paginator import Paginator, encode_cursor
from openlobby.core.models import ReportSort
from openlobby.core.search import search_reports from openlobby.core.search import search_reports
from .dummy import prepare_reports from .dummy import prepare_reports
...@@ -13,7 +14,7 @@ pytestmark = [pytest.mark.django_db, pytest.mark.usefixtures("django_es")] ...@@ -13,7 +14,7 @@ pytestmark = [pytest.mark.django_db, pytest.mark.usefixtures("django_es")]
"query, expected_ids", "query, expected_ids",
[("", [3, 2, 1]), ("sauron", [3, 2]), ("towers", [2]), ("Aragorn Gandalf", [3, 1])], [("", [3, 2, 1]), ("sauron", [3, 2]), ("towers", [2]), ("Aragorn Gandalf", [3, 1])],
) )
def test_search_reports(query, expected_ids): def test_search_reports__query(query, expected_ids):
prepare_reports() prepare_reports()
paginator = Paginator() paginator = Paginator()
response = search_reports(paginator, query=query) response = search_reports(paginator, query=query)
...@@ -35,9 +36,8 @@ def test_search_reports__highlight(): ...@@ -35,9 +36,8 @@ def test_search_reports__highlight():
) )
def test_search_reports__pagination(first, after, expected_ids): def test_search_reports__pagination(first, after, expected_ids):
prepare_reports() prepare_reports()
query = ""
paginator = Paginator(first=first, after=after) paginator = Paginator(first=first, after=after)
response = search_reports(paginator, query=query) response = search_reports(paginator)
assert expected_ids == [int(r.meta.id) for r in response] assert expected_ids == [int(r.meta.id) for r in response]
...@@ -58,3 +58,30 @@ def test_search_reports__by_author__pagination(first, after, expected_ids): ...@@ -58,3 +58,30 @@ def test_search_reports__by_author__pagination(first, after, expected_ids):
paginator = Paginator(first=first, after=after) paginator = Paginator(first=first, after=after)
response = search_reports(paginator, author_id=author_id) response = search_reports(paginator, author_id=author_id)
assert expected_ids == [int(r.meta.id) for r in response] assert expected_ids == [int(r.meta.id) for r in response]
def test_search_reports__sort__default():
prepare_reports()
paginator = Paginator()
response = search_reports(paginator)
assert [3, 2, 1] == [int(r.meta.id) for r in response]
@pytest.mark.parametrize(
"query, sort, reversed, expected_ids",
[
(None, ReportSort.PUBLISHED, False, [3, 2, 1]),
(None, ReportSort.PUBLISHED, True, [1, 2, 3]),
(None, ReportSort.DATE, False, [2, 3, 1]),
(None, ReportSort.DATE, True, [1, 3, 2]),
(None, ReportSort.RELEVANCE, False, [3, 2, 1]),
(None, ReportSort.RELEVANCE, True, [3, 2, 1]),
("ring", ReportSort.RELEVANCE, False, [1, 3]),
("ring", ReportSort.RELEVANCE, True, [3, 1]),
],
)
def test_search_reports__sort(query, sort, reversed, expected_ids):
prepare_reports()
paginator = Paginator()
response = search_reports(paginator, query=query, sort=sort, reversed=reversed)
assert expected_ids == [int(r.meta.id) for r in response]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment