diff --git a/smash/web/tests/functions.py b/smash/web/tests/functions.py index 4adab0a0359062730e5d801833ce0a2ddae4fe3d..cd6915a53a3f6a41227f2bb34fdcddc6f8fd4607 100644 --- a/smash/web/tests/functions.py +++ b/smash/web/tests/functions.py @@ -38,11 +38,15 @@ def create_subject(): country="france") -def create_user(): +def create_user(username=None, password=None): + if username is None: + username = 'piotr' + if password is None: + password = 'top_secret' user = User.objects.create_user( - username='piotr', + username=username, email='jacob@bla', - password='top_secret') + password=password) create_worker(user) return user diff --git a/smash/web/tests/test_view_visit.py b/smash/web/tests/test_view_visit.py index e1979cea43307c6793cc2685ee0f99146170264b..fb60256b0a01e473d1fb5de044d1aa5154838630 100644 --- a/smash/web/tests/test_view_visit.py +++ b/smash/web/tests/test_view_visit.py @@ -1,24 +1,29 @@ -from django.contrib.auth.models import User -from django.test import TestCase, RequestFactory +from django.test import Client +from django.test import TestCase from django.urls import reverse -from functions import create_subject, create_visit, create_appointment +from functions import create_subject, create_visit, create_appointment, create_user from web.views.visit import visit_details class VisitViewTests(TestCase): def setUp(self): - # Every test needs access to the request factory. - self.factory = RequestFactory() - self.user = User.objects.create_user( - username='piotr', email='jacob@bla', password='top_secret') + username = 'piotr' + password = 'top_secret' + + self.client = Client() + self.user = create_user(username, password) + self.client.login(username=username, password=password) def test_visit_details_request(self): subject = create_subject() visit = create_visit(subject) create_appointment(visit) - request = self.factory.get(reverse('web.views.visit_details', args=[visit.id])) - request.user = self.user - response = visit_details(request, visit.id) + response = self.client.get(reverse('web.views.visit_details', args=[visit.id])) self.assertEqual(response.status_code, 200) + + def test_visit_list(self): + visit = create_visit() + + request = self.client.get(reverse('web.views.visits'))