From 64b8f67e477e6ab4ae332086fd7b70f1d605d1c4 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Tue, 12 Dec 2017 11:42:21 +0100
Subject: [PATCH] voucher types are limited to the set allowed for the subject

---
 smash/web/forms/voucher_forms.py            |  4 ++++
 smash/web/tests/forms/test_voucher_forms.py |  4 +++-
 smash/web/tests/view/test_voucher.py        |  8 ++++++--
 smash/web/views/voucher.py                  | 17 ++++++++++++++---
 4 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/smash/web/forms/voucher_forms.py b/smash/web/forms/voucher_forms.py
index 68de4581..99bb8973 100644
--- a/smash/web/forms/voucher_forms.py
+++ b/smash/web/forms/voucher_forms.py
@@ -38,7 +38,11 @@ class VoucherForm(ModelForm):
         fields = '__all__'
 
     def __init__(self, *args, **kwargs):
+        voucher_types = kwargs.pop('voucher_types', VoucherType.objects.all())
         super(VoucherForm, self).__init__(*args, **kwargs)
+
+        self.fields['voucher_type'].queryset = voucher_types
+
         self.fields['number'].widget.attrs['readonly'] = True
         self.fields['number'].required = False
 
diff --git a/smash/web/tests/forms/test_voucher_forms.py b/smash/web/tests/forms/test_voucher_forms.py
index e7310f34..64a78eef 100644
--- a/smash/web/tests/forms/test_voucher_forms.py
+++ b/smash/web/tests/forms/test_voucher_forms.py
@@ -16,14 +16,16 @@ class VoucherFormTests(LoggedInWithWorkerTestCase):
         super(VoucherFormTests, self).setUp()
 
     def test_auto_generated_use_date(self):
+        voucher_type = create_voucher_type()
         study_subject = create_study_subject()
+        study_subject.voucher_types.add(voucher_type)
         create_voucher(study_subject)
 
         voucher_form = VoucherForm()
         form_data = {
             "status": VOUCHER_STATUS_USED,
             "usage_partner": str(self.worker.id),
-            "voucher_type": create_voucher_type().id
+            "voucher_type": voucher_type.id
         }
         for key, value in voucher_form.initial.items():
             form_data[key] = format_form_field(value)
diff --git a/smash/web/tests/view/test_voucher.py b/smash/web/tests/view/test_voucher.py
index 871bf563..58332849 100644
--- a/smash/web/tests/view/test_voucher.py
+++ b/smash/web/tests/view/test_voucher.py
@@ -13,7 +13,9 @@ logger = logging.getLogger(__name__)
 
 class VoucherTypeViewTests(LoggedInTestCase):
     def test_render_add_voucher_request(self):
-        response = self.client.get(reverse('web.views.voucher_add'))
+        study_subject = create_study_subject()
+        url = reverse('web.views.voucher_add') + "?study_subject_id=" + str(study_subject.id)
+        response = self.client.get(url)
         self.assertEqual(response.status_code, 200)
 
     def test_render_edit_voucher_request(self):
@@ -22,11 +24,13 @@ class VoucherTypeViewTests(LoggedInTestCase):
         self.assertEqual(response.status_code, 200)
 
     def test_add_voucher(self):
+        voucher_type = create_voucher_type()
         study_subject = create_study_subject()
+        study_subject.voucher_types.add(voucher_type)
         visit_detail_form = VoucherForm()
         form_data = {
             "status": VOUCHER_STATUS_NEW,
-            "voucher_type": create_voucher_type().id
+            "voucher_type": voucher_type.id
         }
         for key, value in visit_detail_form.initial.items():
             form_data[key] = format_form_field(value)
diff --git a/smash/web/views/voucher.py b/smash/web/views/voucher.py
index 840e5d34..4ca8361c 100644
--- a/smash/web/views/voucher.py
+++ b/smash/web/views/voucher.py
@@ -8,7 +8,7 @@ from django.views.generic import ListView
 from django.views.generic import UpdateView
 
 from web.forms import VoucherForm
-from web.models import Voucher
+from web.models import Voucher, StudySubject
 from web.models.constants import GLOBAL_STUDY_ID
 from . import WrappedView
 
@@ -21,6 +21,10 @@ class VoucherListView(ListView, WrappedView):
     template_name = 'vouchers/list.html'
 
 
+def voucher_types_for_study_subject(study_subject_id):
+    return StudySubject.objects.get(id=study_subject_id).voucher_types.all()
+
+
 class VoucherCreateView(CreateView, WrappedView):
     form_class = VoucherForm
     model = Voucher
@@ -39,6 +43,11 @@ class VoucherCreateView(CreateView, WrappedView):
         # noinspection PyUnresolvedReferences
         return reverse_lazy('web.views.subject_edit', kwargs={'id': self.request.GET.get("study_subject_id", -1)})
 
+    def get_form_kwargs(self):
+        kwargs = super(VoucherCreateView, self).get_form_kwargs()
+        kwargs['voucher_types'] = voucher_types_for_study_subject(self.request.GET.get("study_subject_id", -1))
+        return kwargs
+
 
 class VoucherEditView(SuccessMessageMixin, UpdateView, WrappedView):
     form_class = VoucherForm
@@ -51,5 +60,7 @@ class VoucherEditView(SuccessMessageMixin, UpdateView, WrappedView):
 
     def get_success_url(self, **kwargs):
         # noinspection PyUnresolvedReferences
-        study_subject_id = Voucher.objects.get(id=self.kwargs['pk']).study_subject.id
-        return reverse_lazy('web.views.subject_edit', kwargs={'id': study_subject_id})
+        return reverse_lazy('web.views.subject_edit', kwargs={'id': self.get_study_subject_id()})
+
+    def get_study_subject_id(self):
+        return Voucher.objects.get(id=self.kwargs['pk']).study_subject.id
-- 
GitLab