From c92adf295a56a004afcb52512cd34f30d27ffeb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Bedna=C5=99=C3=ADk?= <jan.bednarik@gmail.com>
Date: Sat, 17 Feb 2018 20:41:31 +0100
Subject: [PATCH] Custom middleware for JWT token auth.

---
 openlobby/core/api/utils.py             |  6 ++
 openlobby/core/auth.py                  |  6 --
 openlobby/core/middleware.py            | 33 +++++++++++
 openlobby/settings.py                   |  1 +
 tests/snapshots/snap_test_middleware.py | 24 ++++++++
 tests/test_middleware.py                | 77 +++++++++++++++++++++++++
 6 files changed, 141 insertions(+), 6 deletions(-)
 create mode 100644 openlobby/core/api/utils.py
 create mode 100644 tests/snapshots/snap_test_middleware.py
 create mode 100644 tests/test_middleware.py

diff --git a/openlobby/core/api/utils.py b/openlobby/core/api/utils.py
new file mode 100644
index 0000000..e2503c6
--- /dev/null
+++ b/openlobby/core/api/utils.py
@@ -0,0 +1,6 @@
+from django.http import JsonResponse
+
+
+def graphql_error_response(message, status_code=400):
+    error = {'message': message}
+    return JsonResponse({'errors': [error]}, status=status_code)
diff --git a/openlobby/core/auth.py b/openlobby/core/auth.py
index fc45cb4..b49b3a9 100644
--- a/openlobby/core/auth.py
+++ b/openlobby/core/auth.py
@@ -1,5 +1,4 @@
 from django.conf import settings
-import json
 import jwt
 import time
 
@@ -18,8 +17,3 @@ def create_access_token(username, expiration=None):
 def parse_access_token(token):
     payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
     return payload['sub']
-
-
-def graphql_error_response(message, code=400):
-    error = {'message': message}
-    return json.dumps({'errors': [error]}), code, {'Content-Type': 'application/json'}
diff --git a/openlobby/core/middleware.py b/openlobby/core/middleware.py
index e69de29..a6c1083 100644
--- a/openlobby/core/middleware.py
+++ b/openlobby/core/middleware.py
@@ -0,0 +1,33 @@
+import re
+
+from .api.utils import graphql_error_response
+from .auth import parse_access_token
+from .models import User
+
+
+class TokenAuthMiddleware:
+    """Custom authentication middleware which using JWT token."""
+
+    def __init__(self, get_response):
+        self.get_response = get_response
+
+    def __call__(self, request):
+        auth_header = request.META.get('HTTP_AUTHORIZATION')
+        if auth_header is not None:
+            m = re.match(r'Bearer (?P<token>.+)', auth_header)
+            if m:
+                token = m.group('token')
+            else:
+                return graphql_error_response('Wrong Authorization header. Expected: "Bearer <token>"')
+
+            try:
+                username = parse_access_token(token)
+            except Exception:
+                return graphql_error_response('Invalid Token.', 401)
+
+            try:
+                request.user = User.objects.get(username=username)
+            except User.DoesNotExist:
+                pass
+
+        return self.get_response(request)
diff --git a/openlobby/settings.py b/openlobby/settings.py
index 5fe4023..dad0d61 100644
--- a/openlobby/settings.py
+++ b/openlobby/settings.py
@@ -37,6 +37,7 @@ MIDDLEWARE = [
     'django.middleware.common.CommonMiddleware',
     'django.middleware.csrf.CsrfViewMiddleware',
     'django.contrib.auth.middleware.AuthenticationMiddleware',
+    'openlobby.core.middleware.TokenAuthMiddleware',
     'django.middleware.clickjacking.XFrameOptionsMiddleware',
     'django.contrib.messages.middleware.MessageMiddleware',
 ]
diff --git a/tests/snapshots/snap_test_middleware.py b/tests/snapshots/snap_test_middleware.py
new file mode 100644
index 0000000..2be80f5
--- /dev/null
+++ b/tests/snapshots/snap_test_middleware.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# snapshottest: v1 - https://goo.gl/zC4yUc
+from __future__ import unicode_literals
+
+from snapshottest import Snapshot
+
+
+snapshots = Snapshot()
+
+snapshots['test_wrong_header 1'] = {
+    'errors': [
+        {
+            'message': 'Wrong Authorization header. Expected: "Bearer <token>"'
+        }
+    ]
+}
+
+snapshots['test_invalid_token 1'] = {
+    'errors': [
+        {
+            'message': 'Invalid Token.'
+        }
+    ]
+}
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
new file mode 100644
index 0000000..37dbd17
--- /dev/null
+++ b/tests/test_middleware.py
@@ -0,0 +1,77 @@
+import pytest
+import json
+from unittest.mock import Mock
+
+from openlobby.core.auth import create_access_token
+from openlobby.core.middleware import TokenAuthMiddleware
+from openlobby.core.models import User
+
+
+pytestmark = pytest.mark.django_db
+
+
+def test_no_auth_header():
+    request = Mock()
+    request.user = None
+    request.META.get.return_value = None
+
+    middleware = TokenAuthMiddleware(lambda r: r)
+    response = middleware(request)
+
+    request.META.get.assert_called_once_with('HTTP_AUTHORIZATION')
+    assert response == request
+    assert response.user is None
+
+
+def test_authorized_user():
+    user = User.objects.create(username='wolfe', first_name='Winston',
+        last_name='Wolfe', email='winston@wolfe.com')
+    request = Mock()
+    request.user = None
+    request.META.get.return_value = 'Bearer {}'.format(create_access_token('wolfe'))
+
+    middleware = TokenAuthMiddleware(lambda r: r)
+    response = middleware(request)
+
+    request.META.get.assert_called_once_with('HTTP_AUTHORIZATION')
+    assert response == request
+    assert response.user == user
+
+
+def test_wrong_header(snapshot):
+    request = Mock()
+    request.user = None
+    request.META.get.return_value = 'WRONG {}'.format(create_access_token('unknown'))
+
+    middleware = TokenAuthMiddleware(lambda r: r)
+    response = middleware(request)
+
+    request.META.get.assert_called_once_with('HTTP_AUTHORIZATION')
+    assert response.status_code == 400
+    snapshot.assert_match(json.loads(response.content))
+
+
+def test_invalid_token(snapshot):
+    request = Mock()
+    request.user = None
+    request.META.get.return_value = 'Bearer XXX{}'.format(create_access_token('unknown'))
+
+    middleware = TokenAuthMiddleware(lambda r: r)
+    response = middleware(request)
+
+    request.META.get.assert_called_once_with('HTTP_AUTHORIZATION')
+    assert response.status_code == 401
+    snapshot.assert_match(json.loads(response.content))
+
+
+def test_unknown_user(snapshot):
+    request = Mock()
+    request.user = None
+    request.META.get.return_value = 'Bearer {}'.format(create_access_token('unknown'))
+
+    middleware = TokenAuthMiddleware(lambda r: r)
+    response = middleware(request)
+
+    request.META.get.assert_called_once_with('HTTP_AUTHORIZATION')
+    assert response == request
+    assert response.user is None
-- 
GitLab