diff --git a/smash/web/forms.py b/smash/web/forms.py index 5b1d8c25a19162761a85cd8fd9ead0a5a2d07565..c5499f0e5afa216284e0d651d4f52bd7b9b17c60 100644 --- a/smash/web/forms.py +++ b/smash/web/forms.py @@ -128,6 +128,24 @@ class AppointmentEditForm(ModelForm): widget=forms.DateTimeInput(DATETIMEPICKER_DATE_ATTRS) ) + def __init__(self, *args, **kwargs): + user = kwargs.pop('user', None) + if user==None: + raise TypeError("User not defined") + self.user = Worker.get_by_user(user) + if self.user==None: + raise TypeError("Worker not defined for: " + user.username) + + super(ModelForm, self).__init__(*args,**kwargs) + + def clean_location(self): + location = self.cleaned_data['location'] + if self.user.locations.filter(id = location.id).count()==0: + self.add_error('location', "You cannot create appointment for this location") + else: + return location + + class AppointmentAddForm(ModelForm): class Meta: model = Appointment @@ -137,6 +155,24 @@ class AppointmentAddForm(ModelForm): widget=forms.DateTimeInput(DATETIMEPICKER_DATE_ATTRS) ) + def __init__(self, *args, **kwargs): + user = kwargs.pop('user', None) + if user==None: + raise TypeError("User not defined") + self.user = Worker.get_by_user(user) + if self.user==None: + raise TypeError("Worker not defined for: " + user.username) + + super(ModelForm, self).__init__(*args,**kwargs) + + def clean_location(self): + location = self.cleaned_data['location'] + if self.user.locations.filter(id = location.id).count()==0: + self.add_error('location', "You cannot create appointment for this location") + else: + return location + + class VisitDetailForm(ModelForm): datetime_begin = forms.DateField(label="Visit begins on", widget=forms.DateInput(DATEPICKER_DATE_ATTRS, "%Y-%m-%d") @@ -164,7 +200,6 @@ class VisitAddForm(ModelForm): exclude = ['is_finished'] def clean(self): - print self.cleaned_data['appointment_types'] if (self.cleaned_data['datetime_begin']>=self.cleaned_data['datetime_end']): self.add_error('datetime_begin', "Start date must be before end date") self.add_error('datetime_end', "Start date must be before end date") diff --git a/smash/web/models.py b/smash/web/models.py index 71df4a06b74a8a6979d809b9b40aa8ef58ff273e..83dc148eacfcddd156eaa7bf88608411f0c11f5b 100644 --- a/smash/web/models.py +++ b/smash/web/models.py @@ -380,6 +380,21 @@ class Worker (models.Model): return False + @staticmethod + def get_by_user(the_user): + if isinstance(the_user, User): + workers = Worker.objects.filter(user=the_user) + if len(workers)>0: + return workers[0] + else: + return None + elif isinstance(user, Worker): + return user + elif user!=None: + raise TypeError("Unknown class type: "+user.__class__.__name__) + else: + return None + @staticmethod def get_details(the_user): if the_user.is_authenticated == False: diff --git a/smash/web/templates/subjects/index.html b/smash/web/templates/subjects/index.html index 889fb48da4794ed960c6470aeedd36fefb200d3c..252bef06d8ac685754cb5855010e46cc25a05d64 100644 --- a/smash/web/templates/subjects/index.html +++ b/smash/web/templates/subjects/index.html @@ -51,7 +51,7 @@ <td>{{ subject.screening_number }}</td> <td>{{ subject.first_name }}</td> <td>{{ subject.last_name }}</td> - <td>{{ subject.get_default_appointment_location_display }}</td> + <td>{{ subject.default_location }}</td> <td>{% if subject.dead %} YES {% else %} NO {% endif %} </td> <td>{% if subject.resigned %} YES {% else %} NO {% endif %} </td> <td>{% if subject.postponed %} YES {% else %} NO {% endif %} </td> diff --git a/smash/web/tests/functions.py b/smash/web/tests/functions.py index 520601b637e923bc3825a17787eccd68caf1a6c7..afd97ec1caf489279d723c94752eb242e29e3278 100644 --- a/smash/web/tests/functions.py +++ b/smash/web/tests/functions.py @@ -28,16 +28,21 @@ def create_subject(): sex= Subject.SEX_CHOICES_MALE) def create_user(): - return User.objects.create_user( + user = User.objects.create_user( username='piotr', email='jacob@bla', password='top_secret') -def create_worker(): + create_worker(user) + return user + + +def create_worker(user = None): return Worker.objects.create( first_name='piotr', last_name="gawron", email='jacob@bla', + user = user, ) def create_visit(subject = None): diff --git a/smash/web/tests/test_AppointmentAddForm.py b/smash/web/tests/test_AppointmentAddForm.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfc3e7d154ffa4cf78f0086bc5fc1ba3bd86c38 --- /dev/null +++ b/smash/web/tests/test_AppointmentAddForm.py @@ -0,0 +1,34 @@ +from django.test import TestCase +from web.forms import AppointmentAddForm +from web.models import Subject + +from web.tests.functions import * + +class AppointmentAddFormTests(TestCase): + def setUp(self): + location = get_test_location() + self.user = create_user() + + worker = Worker.get_by_user(self.user) + worker.locations = [get_test_location()] + worker.save() + + self.visit = create_visit() + + self.sample_data = {'first_name': 'name', + 'length': '50', + 'visit' : self.visit.id, + 'location' : location.id, + 'datetime_when' : "2020-01-01", + } + + def test_validation(self): + form = AppointmentAddForm(user=self.user, data=self.sample_data) + self.assertTrue(form.is_valid()) + + def test_validation_invalid_location(self): + self.sample_data['location'] = create_location(name="xxx").id + form = AppointmentAddForm(user=self.user, data=self.sample_data) + + self.assertFalse(form.is_valid()) + self.assertTrue("location" in form.errors) diff --git a/smash/web/tests/test_AppointmentEditForm.py b/smash/web/tests/test_AppointmentEditForm.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ef88cc72a02279e1e6793e24ef12f79589bfcf --- /dev/null +++ b/smash/web/tests/test_AppointmentEditForm.py @@ -0,0 +1,38 @@ +from django.test import TestCase +from web.forms import AppointmentEditForm +from web.models import Subject + +from web.tests.functions import * + +class AppointmentEditFormTests(TestCase): + def setUp(self): + location = get_test_location() + self.user = create_user() + + worker = Worker.get_by_user(self.user) + worker.locations = [get_test_location()] + worker.save() + + self.visit = create_visit() + + self.sample_data = {'first_name': 'name', + 'length': '50', + 'visit' : self.visit.id, + 'location' : location.id, + 'datetime_when' : "2020-01-01", + } + + add_form = AppointmentAddForm(user=self.user, data=self.sample_data) + self.appointment = add_form.save() + + + def test_validation(self): + form = AppointmentEditForm(user=self.user, data=self.sample_data) + self.assertTrue(form.is_valid()) + + def test_validation_invalid_location(self): + self.sample_data['location'] = create_location(name="xxx").id + form = AppointmentEditForm(user=self.user, data=self.sample_data) + + self.assertFalse(form.is_valid()) + self.assertTrue("location" in form.errors) diff --git a/smash/web/views.py b/smash/web/views.py index 4b5ee3e2a3970098c8d42e615199542d0110304c..6734692cc107937e9b639a04808ee79bedf3f269 100644 --- a/smash/web/views.py +++ b/smash/web/views.py @@ -655,13 +655,13 @@ def appointment_details(request, id): def appointment_add(request, id): full_list = get_calendar_full_appointments(request.user) if request.method == 'POST': - form = AppointmentAddForm(request.POST, request.FILES) + form = AppointmentAddForm(request.POST, request.FILES, user=request.user) form.fields['visit'].widget = forms.HiddenInput() if form.is_valid(): form.save() return redirect(visit_details, id=id) else: - form = AppointmentAddForm(initial={'visit': id}) + form = AppointmentAddForm(initial={'visit': id}, user=request.user) form.fields['visit'].widget = forms.HiddenInput() return wrap_response(request, 'appointments/add.html', {'form': form, 'visitID': id, 'full_appointment_list': full_list}) @@ -669,7 +669,11 @@ def appointment_add(request, id): def appointment_edit(request, id): the_appointment = get_object_or_404(Appointment, id=id) if request.method == 'POST': - form = AppointmentEditForm(request.POST, request.FILES, instance=the_appointment) + form = AppointmentEditForm(request.POST, + request.FILES, + instance=the_appointment, + user=request.user + ) if form.is_valid(): form.save() @@ -679,7 +683,7 @@ def appointment_edit(request, id): return redirect(appointments) else: - form = AppointmentEditForm(instance=the_appointment) + form = AppointmentEditForm(instance=the_appointment,user=request.user) subject_form = SubjectDetailForm(instance=the_appointment.visit.subject) @@ -694,13 +698,13 @@ def appointment_edit(request, id): def appointment_edit_datetime(request, id): the_appointment = get_object_or_404(Appointment, id=id) if request.method == 'POST': - form = AppointmentEditForm(request.POST, request.FILES, instance=the_appointment) + form = AppointmentEditForm(request.POST, request.FILES, instance=the_appointment,user=request.user) if form.is_valid(): form.save() return redirect(appointments) else: the_appointment.datetime_when = the_appointment.visit.datetime_begin - form = AppointmentEditForm(instance=the_appointment) + form = AppointmentEditForm(instance=the_appointment, user=request.user) return wrap_response(request, 'appointments/edit.html', {'form': form}) #because we don't wrap_response we must force login required