From ee8aefda714667a4783218204e995e729f82f6b6 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Fri, 1 Dec 2017 11:00:55 +0100
Subject: [PATCH] StudySubjectAddForm requires study parameter

during save study is automatically injected into model
---
 smash/web/docx_helper.py                      |  4 +++
 smash/web/forms.py                            | 29 +++++++++++++++----
 .../tests/forms/test_StudySubjectAddForm.py   | 17 +++++------
 smash/web/tests/view/test_subjects.py         |  5 ++--
 smash/web/views/subject.py                    | 10 +++----
 5 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/smash/web/docx_helper.py b/smash/web/docx_helper.py
index fb78711c..49c2fc87 100644
--- a/smash/web/docx_helper.py
+++ b/smash/web/docx_helper.py
@@ -1,5 +1,9 @@
+import logging
+
 from docx import Document
 
+logger = logging.getLogger(__name__)
+
 
 def process_file(path_to_docx, path_to_new_docx, changes_to_apply):
     """
diff --git a/smash/web/forms.py b/smash/web/forms.py
index ade762d5..0c062b1a 100644
--- a/smash/web/forms.py
+++ b/smash/web/forms.py
@@ -66,6 +66,23 @@ def validate_subject_mpower_number(self, cleaned_data):
                 self.add_error('mpower_id', "mPower number already in use")
 
 
+def get_worker_from_args(kwargs):
+    user = kwargs.pop('user', None)
+    if user is None:
+        raise TypeError("User not defined")
+    result = Worker.get_by_user(user)
+    if result is None:
+        raise TypeError("Worker not defined for: " + user.username)
+    return result
+
+
+def get_study_from_args(kwargs):
+    study = kwargs.pop('study', None)
+    if study is None:
+        raise TypeError("Study not defined")
+    return study
+
+
 class StudySubjectAddForm(ModelForm):
     datetime_contact_reminder = forms.DateTimeField(label="Contact on",
                                                     widget=forms.DateTimeInput(DATETIMEPICKER_DATE_ATTRS),
@@ -78,16 +95,16 @@ class StudySubjectAddForm(ModelForm):
         exclude = ['resigned', 'resign_reason']
 
     def __init__(self, *args, **kwargs):
-        user = kwargs.pop('user', None)
-        if user is None:
-            raise TypeError("User not defined")
-        self.user = Worker.get_by_user(user)
-        if self.user is None:
-            raise TypeError("Worker not defined for: " + user.username)
+        self.user = get_worker_from_args(kwargs)
+        self.study = get_study_from_args(kwargs)
 
         super(ModelForm, self).__init__(*args, **kwargs)
         self.fields['screening_number'].required = False
 
+    def save(self, commit=True):
+        self.instance.study_id = self.study.id
+        return super(ModelForm, self).save(commit)
+
     def build_screening_number(self, cleaned_data):
         screening_number = cleaned_data.get('screening_number', None)
         if not screening_number:
diff --git a/smash/web/tests/forms/test_StudySubjectAddForm.py b/smash/web/tests/forms/test_StudySubjectAddForm.py
index 21ff5486..1a206bff 100644
--- a/smash/web/tests/forms/test_StudySubjectAddForm.py
+++ b/smash/web/tests/forms/test_StudySubjectAddForm.py
@@ -23,7 +23,7 @@ class StudySubjectAddFormTests(LoggedInWithWorkerTestCase):
         }
 
     def test_validation(self):
-        form = StudySubjectAddForm(data=self.sample_data, user=self.user)
+        form = StudySubjectAddForm(data=self.sample_data, user=self.user, study=self.study)
         form.is_valid()
         self.assertTrue(form.is_valid())
 
@@ -31,15 +31,14 @@ class StudySubjectAddFormTests(LoggedInWithWorkerTestCase):
         form_data = self.sample_data
         form_data['screening_number'] = "123"
 
-        form = StudySubjectAddForm(data=form_data, user=self.user)
+        form = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         form.is_valid()
         form.instance.subject_id = self.subject.id
-        form.instance.study_id = self.study.id
         self.assertTrue(form.is_valid())
         self.assertIsNone(form.fields['year_of_diagnosis'].initial)
         form.save()
 
-        form2 = StudySubjectAddForm(data=form_data, user=self.user)
+        form2 = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         validation_status = form2.is_valid()
         self.assertFalse(validation_status)
         self.assertTrue("screening_number" in form2.errors)
@@ -48,15 +47,14 @@ class StudySubjectAddFormTests(LoggedInWithWorkerTestCase):
         form_data = self.sample_data
         form_data['nd_number'] = "ND0123"
 
-        form = StudySubjectAddForm(data=form_data, user=self.user)
+        form = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         form.is_valid()
         self.assertTrue(form.is_valid())
         form.instance.subject_id = self.subject.id
-        form.instance.study_id = self.study.id
         form.save()
 
         form_data['screening_number'] = "2"
-        form2 = StudySubjectAddForm(data=form_data, user=self.user)
+        form2 = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         validation_status = form2.is_valid()
         self.assertFalse(validation_status)
         self.assertTrue("nd_number" in form2.errors)
@@ -65,15 +63,14 @@ class StudySubjectAddFormTests(LoggedInWithWorkerTestCase):
         form_data = self.sample_data
         form_data['mpower_id'] = "123"
 
-        form = StudySubjectAddForm(data=form_data, user=self.user)
+        form = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         form.is_valid()
         self.assertTrue(form.is_valid())
         form.instance.subject_id = self.subject.id
-        form.instance.study_id = self.study.id
         form.save()
 
         form_data['screening_number'] = "2"
-        form2 = StudySubjectAddForm(data=form_data, user=self.user)
+        form2 = StudySubjectAddForm(data=form_data, user=self.user, study=self.study)
         validation_status = form2.is_valid()
         self.assertFalse(validation_status)
         self.assertTrue("mpower_id" in form2.errors)
diff --git a/smash/web/tests/view/test_subjects.py b/smash/web/tests/view/test_subjects.py
index 83820e01..7d47368c 100644
--- a/smash/web/tests/view/test_subjects.py
+++ b/smash/web/tests/view/test_subjects.py
@@ -9,7 +9,7 @@ from web.models.constants import SEX_CHOICES_MALE, SUBJECT_TYPE_CHOICES_CONTROL,
     COUNTRY_AFGHANISTAN_ID, COUNTRY_OTHER_ID, MAIL_TEMPLATE_CONTEXT_SUBJECT
 from web.tests import LoggedInWithWorkerTestCase
 from web.tests.functions import create_study_subject, create_visit, create_appointment, get_test_location, \
-    create_language, get_resource_path
+    create_language, get_resource_path, get_test_study
 from web.views.notifications import get_today_midnight_date
 
 logger = logging.getLogger(__name__)
@@ -19,6 +19,7 @@ class SubjectsViewTests(LoggedInWithWorkerTestCase):
     def setUp(self):
         super(SubjectsViewTests, self).setUp()
         self.study_subject = create_study_subject()
+        self.study = get_test_study()
 
     def test_render_subjects_add(self):
         self.worker.save()
@@ -105,7 +106,7 @@ class SubjectsViewTests(LoggedInWithWorkerTestCase):
         return form_data
 
     def create_add_form_data_for_study_subject(self):
-        form_study_subject = StudySubjectAddForm(prefix="study_subject", user=self.user)
+        form_study_subject = StudySubjectAddForm(prefix="study_subject", user=self.user, study=self.study)
         form_subject = SubjectAddForm(prefix="subject")
         form_data = {}
         for key, value in form_study_subject.initial.items():
diff --git a/smash/web/views/subject.py b/smash/web/views/subject.py
index d8d04046..7542d73b 100644
--- a/smash/web/views/subject.py
+++ b/smash/web/views/subject.py
@@ -4,10 +4,10 @@ import logging
 from django.contrib import messages
 from django.shortcuts import redirect, get_object_or_404
 
-from ..models.constants import GLOBAL_STUDY_ID
 from . import wrap_response
 from ..forms import StudySubjectAddForm, StudySubjectEditForm, VisitDetailForm, SubjectEditForm, SubjectAddForm
-from ..models import StudySubject, MailTemplate, Worker
+from ..models import StudySubject, MailTemplate, Worker, Study
+from ..models.constants import GLOBAL_STUDY_ID
 
 SUBJECT_LIST_GENERIC = "GENERIC"
 SUBJECT_LIST_NO_VISIT = "NO_VISIT"
@@ -25,14 +25,14 @@ def subjects(request):
 
 
 def subject_add(request):
+    study = Study.objects.filter(id=GLOBAL_STUDY_ID)[0]
     if request.method == 'POST':
-        study_subject_form = StudySubjectAddForm(request.POST, request.FILES, prefix="study_subject", user=request.user)
+        study_subject_form = StudySubjectAddForm(request.POST, request.FILES, prefix="study_subject", user=request.user, study=study)
 
         subject_form = SubjectAddForm(request.POST, request.FILES, prefix="subject")
         if study_subject_form.is_valid() and subject_form.is_valid():
             subject = subject_form.save()
             study_subject_form.instance.subject_id = subject.id
-            study_subject_form.instance.study_id = GLOBAL_STUDY_ID
             study_subject_form.save()
             messages.add_message(request, messages.SUCCESS, 'Subject created')
             return redirect('web.views.subject_edit', id=study_subject_form.instance.id)
@@ -40,7 +40,7 @@ def subject_add(request):
             messages.add_message(request, messages.ERROR, 'Invalid data. Please fix data and try again.')
 
     else:
-        study_subject_form = StudySubjectAddForm(user=request.user, prefix="study_subject")
+        study_subject_form = StudySubjectAddForm(user=request.user, prefix="study_subject", study=study)
         subject_form = SubjectAddForm(prefix="subject")
 
     return wrap_response(request, 'subjects/add.html',
-- 
GitLab