diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py index 185ac5b02436feccf34410aa9aeae068f3e745e5..e6931aa6a16623ebe26960fb16fb2eabaf4e43f1 100644 --- a/smash/web/importer/csv_tns_visit_import_reader.py +++ b/smash/web/importer/csv_tns_visit_import_reader.py @@ -1,6 +1,8 @@ import csv import datetime import logging +import sys +import traceback import pytz from django.conf import settings @@ -51,25 +53,39 @@ class TnsCsvVisitImportReader: visit_number = int(visit_number) + 1 visits = Visit.objects.filter(subject=study_subject, visit_number=visit_number) - if len(visits) > 0: - raise NotImplementedError - date = self.extract_date(data['dateofvisit']) - visit = Visit.objects.create(subject=study_subject, visit_number=visit_number, datetime_begin=date, - datetime_end=date + datetime.timedelta(days=14)) + location = self.extract_location(data['adressofvisit']) + + if len(visits) > 0: + logger.warn("Visit for subject " + nd_number + " already exists. Updating") + visit = visits[0] + visit.datetime_begin = date + visit.datetime_end = date + datetime.timedelta(days=14) + else: + visit = Visit.objects.create(subject=study_subject, visit_number=visit_number, + datetime_begin=date, + datetime_end=date + datetime.timedelta(days=14)) visit.save() result.append(visit) - location = self.extract_location(data['adressofvisit']) - - appointment = Appointment.objects.create(visit=visit, length=60, datetime_when=date, - location=location) + appointments = Appointment.objects.filter(visit=visit, appointment_types=self.appointment_type) + if len(appointments) > 0: + logger.warn("Appointment for subject " + nd_number + " already set. Updating") + appointment = appointments[0] + appointment.length = 60 + appointment.datetime_when = date + appointment.location = location + appointment.save() + else: + appointment = Appointment.objects.create(visit=visit, length=60, datetime_when=date, + location=location) if self.appointment_type is not None: AppointmentTypeLink.objects.create(appointment_id=appointment.id, appointment_type=self.appointment_type) self.processed_count += 1 except: self.problematic_count += 1 + traceback.print_exc(file=sys.stdout) logger.warn("Problematic data: " + ';'.join(row)) if "WARNING" in warning_counter.level2count: diff --git a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py index d93f37bf78186808d37f7628bb33dfc6f9d21b27..cc1a297fbfb766cdd22f8645e51ed79c34adbdc2 100644 --- a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py +++ b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py @@ -1,12 +1,14 @@ # coding=utf-8 +import pytz import logging +from datetime import datetime from django.conf import settings from django.test import TestCase from web.importer import TnsCsvVisitImportReader, MsgCounterHandler -from web.models import Appointment, Visit +from web.models import Appointment, Visit, StudySubject, AppointmentTypeLink, AppointmentType from web.tests.functions import get_resource_path, create_study_subject, create_appointment_type, create_location logger = logging.getLogger(__name__) @@ -17,14 +19,8 @@ class TestTnsCsvSubjectReader(TestCase): self.warning_counter = MsgCounterHandler() logging.getLogger('').addHandler(self.warning_counter) setattr(settings, "IMPORT_APPOINTMENT_TYPE", "SAMPLE_2") - create_appointment_type(code = "SAMPLE_2") + create_appointment_type(code="SAMPLE_2") - - def tearDown(self): - setattr(settings, "IMPORT_APPOINTMENT_TYPE", None) - logging.getLogger('').removeHandler(self.warning_counter) - - def test_load_data(self): create_study_subject(nd_number='cov-000111') create_study_subject(nd_number='cov-222333') create_study_subject(nd_number='cov-444444') @@ -33,6 +29,11 @@ class TestTnsCsvSubjectReader(TestCase): create_location(name=u"PickenDoheem") create_location(name=u"Ketterthill 1-3, rue de la Continentale 4917 Bascharage") + def tearDown(self): + setattr(settings, "IMPORT_APPOINTMENT_TYPE", None) + logging.getLogger('').removeHandler(self.warning_counter) + + def test_load_data(self): filename = get_resource_path('tns_vouchers_import.csv') visits = TnsCsvVisitImportReader().load_data(filename) self.assertEqual(3, len(visits)) @@ -48,5 +49,58 @@ class TestTnsCsvSubjectReader(TestCase): self.assertEqual(4, appointment.datetime_when.month) self.assertEqual(2020, appointment.datetime_when.year) + self.assertEquals(0, self.get_warnings_count()) + + def test_load_data_with_existing_visit(self): + filename = get_resource_path('tns_vouchers_import.csv') + visit = Visit.objects.create(subject=StudySubject.objects.filter(nd_number='cov-000111')[0], + datetime_end=datetime.now().replace(tzinfo=pytz.UTC), + datetime_begin=datetime.now().replace(tzinfo=pytz.UTC), + visit_number=1) + visits = TnsCsvVisitImportReader().load_data(filename) + visit = Visit.objects.filter(id=visits[0].id)[0] + self.assertEqual("cov-000111", visit.subject.nd_number) + + appointment = Appointment.objects.filter(visit=visit)[0] + self.assertEqual(u"Laboratoires réunis 23 Route de Diekirch 6555 Bollendorf-Pont", + appointment.location.name) + + self.assertEqual(10, appointment.datetime_when.day) + self.assertEqual(4, appointment.datetime_when.month) + self.assertEqual(2020, appointment.datetime_when.year) + + self.assertEquals(1, self.get_warnings_count()) + + def test_load_data_with_existing_visit_and_appointment(self): + filename = get_resource_path('tns_vouchers_import.csv') + visit = Visit.objects.create(subject=StudySubject.objects.filter(nd_number='cov-000111')[0], + datetime_end=datetime.now().replace(tzinfo=pytz.UTC), + datetime_begin=datetime.now().replace(tzinfo=pytz.UTC), + visit_number=1) + appointment = Appointment.objects.create(visit=visit, length=1, datetime_when=datetime.now(), + location=create_location()) + + AppointmentTypeLink.objects.create(appointment_id=appointment.id, + appointment_type=AppointmentType.objects.filter(code="SAMPLE_2")[0]) + + visits = TnsCsvVisitImportReader().load_data(filename) + visit = Visit.objects.filter(id=visits[0].id)[0] + self.assertEqual("cov-000111", visit.subject.nd_number) + + self.assertEquals(1, Appointment.objects.filter(visit=visit).count()) + + appointment = Appointment.objects.filter(visit=visit)[0] + self.assertEqual(u"Laboratoires réunis 23 Route de Diekirch 6555 Bollendorf-Pont", + appointment.location.name) + + self.assertEqual(10, appointment.datetime_when.day) + self.assertEqual(4, appointment.datetime_when.month) + self.assertEqual(2020, appointment.datetime_when.year) + + self.assertEquals(2, self.get_warnings_count()) + + def get_warnings_count(self): if "WARNING" in self.warning_counter.level2count: - self.assertEquals(0, self.warning_counter.level2count["WARNING"]) + return self.warning_counter.level2count["WARNING"] + else: + return 0