From 33f68f19cfef60919410308da30388cea2a79da6 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Tue, 17 Nov 2020 14:27:41 +0100
Subject: [PATCH] filename is associated with VisitImportData

---
 .../importer/csv_tns_visit_import_reader.py   |  3 +-
 smash/web/importer/importer_cron_job.py       | 67 ++++++++++---------
 smash/web/migrations/0179_visitimportdata.py  |  1 +
 .../0180_visitimportdata_migration.py         |  4 +-
 smash/web/models/etl/visit_import.py          | 23 +++++++
 .../tests/importer/test_importer_cron_job.py  | 20 ++++--
 .../test_tns_csv_visit_import_reader.py       | 45 +++++++------
 7 files changed, 103 insertions(+), 60 deletions(-)

diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py
index fa18c566..2c5ed625 100644
--- a/smash/web/importer/csv_tns_visit_import_reader.py
+++ b/smash/web/importer/csv_tns_visit_import_reader.py
@@ -36,7 +36,8 @@ class TnsCsvVisitImportReader:
         if data.import_worker is None:
             logger.warning("Import user is not assigned")
 
-    def load_data(self, filename):
+    def load_data(self):
+        filename = self.visit_import_data.get_absolute_file_path()
         warning_counter = MsgCounterHandler()
         logging.getLogger('').addHandler(warning_counter)
 
diff --git a/smash/web/importer/importer_cron_job.py b/smash/web/importer/importer_cron_job.py
index 5d8e5171..4be9890e 100644
--- a/smash/web/importer/importer_cron_job.py
+++ b/smash/web/importer/importer_cron_job.py
@@ -6,12 +6,13 @@ import os.path
 import traceback
 
 import timeout_decorator
+from django.conf import settings
 from django.db import OperationalError
 from django_cron import CronJobBase, Schedule
 
-from web.models import ConfigurationItem, Study
+from web.models import ConfigurationItem, Study, VisitImportData
 from web.models.constants import CRON_JOB_TIMEOUT, DEFAULT_FROM_EMAIL, DAILY_SUBJECT_IMPORT_FILE, \
-    DAILY_VISIT_IMPORT_FILE, SUBJECT_IMPORT_RUN_AT, VISIT_IMPORT_RUN_AT, GLOBAL_STUDY_ID
+    SUBJECT_IMPORT_RUN_AT, VISIT_IMPORT_RUN_AT, GLOBAL_STUDY_ID
 from web.smash_email import EmailSender
 from .csv_tns_subject_import_reader import TnsCsvSubjectImportReader
 from .csv_tns_visit_import_reader import TnsCsvVisitImportReader
@@ -90,35 +91,54 @@ class VisitImporterCronJob(CronJobBase):
         email_title = "Visits daily import"
         email_recipients = ConfigurationItem.objects.get(type=DEFAULT_FROM_EMAIL).value
 
-        filename = ConfigurationItem.objects.get(type=DAILY_VISIT_IMPORT_FILE).value
-
-        if filename is None or filename == '':
-            logger.info("Importing visits skipped. File not defined ")
-            return "import file not defined"
-        logger.info("Importing visits from file: " + filename)
-        if not os.path.isfile(filename):
-            EmailSender().send_email(email_title,
-                                     "<h3><font color='red'>File with imported data is not available in the system: "
-                                     + filename + "</font></h3>",
-                                     email_recipients)
-            return "import file not found"
-        try:
-            importer = TnsCsvVisitImportReader(self.study.id)
-            importer.load_data(filename)
-            email_body = importer.get_summary()
-            EmailSender().send_email(email_title,
-                                     "<h3>Data was successfully imported from file: " + filename + "</h3>" + email_body,
-                                     email_recipients)
-            self.backup_file(filename)
-            return "import is successful"
-
-        except:
-            tb = traceback.format_exc()
-            EmailSender().send_email(email_title,
-                                     "<h3><font color='red'>There was a problem with importing data from file: "
-                                     + filename + "</font></h3><pre>" + tb + "</pre>",
-                                     email_recipients)
-            return "import crashed"
+        # Run every import definition of the study and report all outcomes; a return inside
+        # the loop would skip every definition after the first one.
+        statuses = []
+        for import_data in VisitImportData.objects.filter(study=self.study):
+            statuses.append(self._import_visits(import_data, email_title, email_recipients))
+        if not statuses:
+            logger.info("Importing visits skipped. No import definition for the study")
+            return "import file not defined"
+        # With a single definition this is exactly the status string returned before.
+        return "; ".join(statuses)
+
+    def _import_visits(self, import_data, email_title, email_recipients):
+        """Import the visits described by one VisitImportData entry and e-mail the outcome.
+
+        Returns one of the status strings "import file not defined", "import file not found",
+        "import is successful" or "import crashed".
+        """
+        if not import_data.filename:
+            logger.info("Importing visits skipped. File not defined ")
+            return "import file not defined"
+        logger.info("Importing visits from file: " + import_data.filename)
+        if not import_data.file_available():
+            content = "<h3><font color='red'>File with imported data is not available in the system: " + \
+                      import_data.filename + "</font></h3>"
+            EmailSender().send_email(email_title, content, email_recipients)
+            return "import file not found"
+
+        filename = import_data.get_absolute_file_path()
+        # noinspection PyBroadException
+        try:
+            importer = TnsCsvVisitImportReader(import_data)
+            importer.load_data()
+            email_body = importer.get_summary()
+            EmailSender().send_email(email_title,
+                                     "<h3>Data was successfully imported from file: " + filename + "</h3>" +
+                                     email_body,
+                                     email_recipients)
+            self.backup_file(filename)
+            return "import is successful"
+        except Exception:
+            # Exception, not BaseException: let KeyboardInterrupt/SystemExit propagate.
+            tb = traceback.format_exc()
+            logger.error("Importing visits from file %s crashed:\n%s", filename, tb)
+            EmailSender().send_email(email_title,
+                                     "<h3><font color='red'>There was a problem with importing data from file: "
+                                     + filename + "</font></h3><pre>" + tb + "</pre>",
+                                     email_recipients)
+            return "import crashed"
 
     def backup_file(self, filename):
         new_file = filename + "-" + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M") + ".bac"
diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py
index 4064c8d2..af88770e 100644
--- a/smash/web/migrations/0179_visitimportdata.py
+++ b/smash/web/migrations/0179_visitimportdata.py
@@ -22,6 +22,7 @@ class Migration(migrations.Migration):
                 ('visit_date_column_name', models.CharField(blank=False, default='dateofvisit', max_length=128, null=False,verbose_name='Visit date column name')),
                 ('location_column_name', models.CharField(blank=False, default='adressofvisit', max_length=128, null=False,verbose_name='Location column name')),
                 ('visit_number_column_name', models.CharField(blank=False, default='visit_id', max_length=128, null=False,verbose_name='Visit number column name')),
+                ('filename', models.CharField(blank=True, default='', max_length=128, null=False, verbose_name='File used for automatic import')),
             ],
         ),
     ]
diff --git a/smash/web/migrations/0180_visitimportdata_migration.py b/smash/web/migrations/0180_visitimportdata_migration.py
index 033ba1e3..830a153a 100644
--- a/smash/web/migrations/0180_visitimportdata_migration.py
+++ b/smash/web/migrations/0180_visitimportdata_migration.py
@@ -1,6 +1,6 @@
 from django.db import migrations
 
-from web.models.constants import IMPORTER_USER, GLOBAL_STUDY_ID, IMPORT_APPOINTMENT_TYPE
+from web.models.constants import IMPORTER_USER, GLOBAL_STUDY_ID, IMPORT_APPOINTMENT_TYPE, DAILY_VISIT_IMPORT_FILE
 
 
 def get_val(apps, item_type: str):
@@ -27,6 +27,7 @@ def create_visit_import_data(apps, schema_editor):
     entry.study = Study.objects.get(pk=GLOBAL_STUDY_ID)
     importer_user_name = get_val(apps, IMPORTER_USER)
     appointment_type_name = get_val(apps, IMPORT_APPOINTMENT_TYPE)
+    import_file = get_val(apps, DAILY_VISIT_IMPORT_FILE)
 
     if importer_user_name is not None:
         workers = Worker.objects.filter(user__username=importer_user_name)
@@ -39,6 +40,9 @@ def create_visit_import_data(apps, schema_editor):
 
         if len(appointment_type) > 0:
             entry.appointment_type = appointment_type[0]
+    # get_val may yield None (its other results are None-checked above), while the
+    # filename column is NOT NULL - fall back to the field default in that case.
+    entry.filename = import_file or ''
     entry.save()
 
 
diff --git a/smash/web/models/etl/visit_import.py b/smash/web/models/etl/visit_import.py
index 528dd997..8edb99bf 100644
--- a/smash/web/models/etl/visit_import.py
+++ b/smash/web/models/etl/visit_import.py
@@ -1,7 +1,12 @@
 # coding=utf-8
+import logging
+import os
 
+from django.conf import settings
 from django.db import models
 
+logger = logging.getLogger(__name__)
+
 
 class VisitImportData(models.Model):
     study = models.ForeignKey("web.Study",
@@ -50,3 +55,24 @@ class VisitImportData(models.Model):
                                                 null=False,
                                                 blank=False
                                                 )
+    filename = models.CharField(max_length=128,
+                                verbose_name='File used for automatic import',
+                                default='',
+                                null=False,
+                                blank=True
+                                )
+
+    def file_available(self):
+        """Tell whether ``filename`` names an existing file located directly in ETL_ROOT."""
+        if not self.filename:
+            return False
+        candidate = self.get_absolute_file_path()
+        # A name with directory parts (or an absolute path) no longer matches its own basename.
+        if os.path.basename(candidate) == self.filename:
+            return os.path.isfile(candidate)
+        logger.warning('File "{}" outside defined ETL_ROOT: {}'.format(self.filename, settings.ETL_ROOT))
+        return False
+
+    def get_absolute_file_path(self) -> str:
+        """Return ``filename`` joined onto ``settings.ETL_ROOT``."""
+        return os.path.join(settings.ETL_ROOT, self.filename)
diff --git a/smash/web/tests/importer/test_importer_cron_job.py b/smash/web/tests/importer/test_importer_cron_job.py
index b187c89d..8647b487 100644
--- a/smash/web/tests/importer/test_importer_cron_job.py
+++ b/smash/web/tests/importer/test_importer_cron_job.py
@@ -1,17 +1,19 @@
 # coding=utf-8
 
 import logging
+import os
 import tempfile
 from shutil import copyfile
 
+from django.conf import settings
 from django.core import mail
 from django.test import TestCase
 from django_cron.models import CronJobLog
 
 from web.importer import SubjectImporterCronJob, VisitImporterCronJob
-from web.models import ConfigurationItem, Visit
-from web.models.constants import DAILY_SUBJECT_IMPORT_FILE, DAILY_VISIT_IMPORT_FILE
-from web.tests.functions import get_resource_path, get_test_study
+from web.models import ConfigurationItem, Visit, VisitImportData
+from web.models.constants import DAILY_SUBJECT_IMPORT_FILE
+from web.tests.functions import get_resource_path, get_test_study, create_appointment_type, create_worker
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +22,11 @@ class TestCronJobImporter(TestCase):
 
     def setUp(self):
         self.study = get_test_study()
+        self.study.redcap_first_visit_number = 0
+        self.study.save()
+        # Import definition used by the visit cron job; its filename stays empty until a test sets one.
+        self.visit_import_data = VisitImportData.objects.create(
+            study=self.study, appointment_type=create_appointment_type(), import_worker=create_worker())
 
     def test_import_without_configuration(self):
         CronJobLog.objects.all().delete()
@@ -53,9 +60,14 @@ class TestCronJobImporter(TestCase):
         new_file, tmp = tempfile.mkstemp()
         copyfile(filename, tmp)
 
-        conf = ConfigurationItem.objects.get(type=DAILY_VISIT_IMPORT_FILE)
-        conf.value = tmp
-        conf.save()
+        # Override ETL_ROOT for this test only; assigning to settings directly would leak
+        # the temporary directory into every test that runs afterwards.
+        etl_root_override = self.settings(ETL_ROOT=os.path.dirname(tmp))
+        etl_root_override.enable()
+        self.addCleanup(etl_root_override.disable)
+
+        self.visit_import_data.filename = os.path.basename(tmp)
+        self.visit_import_data.save()
         CronJobLog.objects.all().delete()
 
         job = VisitImporterCronJob(study_id=self.study.id)
diff --git a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
index 99c36a32..c65ff64c 100644
--- a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
+++ b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
@@ -1,7 +1,9 @@
 # coding=utf-8
 
 import logging
+import os
 
+from django.conf import settings
 from django.test import TestCase
 from django.utils import timezone
 
@@ -21,6 +23,11 @@ class TestTnsCsvVisitReader(TestCase):
         study = get_test_study()
         study.redcap_first_visit_number = 0
         study.save()
+        # Point ETL_ROOT at the test resources for this test case only; a plain assignment
+        # to settings would stay in effect for all tests executed later.
+        etl_root_override = self.settings(ETL_ROOT=os.path.dirname(get_resource_path('tns_vouchers_import.csv')))
+        etl_root_override.enable()
+        self.addCleanup(etl_root_override.disable)
         self.visit_import_data = VisitImportData.objects.create(study=get_test_study(),
                                                                 appointment_type=appointment_type,
                                                                 import_worker=create_worker())
@@ -39,8 +42,8 @@ class TestTnsCsvVisitReader(TestCase):
         logging.getLogger('').removeHandler(self.warning_counter)
 
     def test_load_data(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
-        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
+        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data()
         self.assertEqual(3, len(visits))
         visit = Visit.objects.filter(id=visits[0].id)[0]
         self.assertEqual("cov-000111", visit.subject.nd_number)
@@ -57,14 +60,14 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual(0, self.get_warnings_count())
 
     def test_data_provenance_for_update_visit_load_data(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
 
         Visit.objects.create(subject=StudySubject.objects.get(nd_number='cov-000111', study=get_test_study()),
                              visit_number=1,
                              datetime_begin=timezone.now(),
                              datetime_end=timezone.now())
 
-        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)[0]
+        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data()[0]
 
         self.assertEqual(1, Provenance.objects.filter(modified_table=Visit._meta.db_table,
                                                       modified_table_id=visit.id,
@@ -78,7 +81,7 @@ class TestTnsCsvVisitReader(TestCase):
                                                       previous_value__exact='').count())
 
     def test_data_provenance_for_update_appointment_load_data(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
 
         old_visit = Visit.objects.create(
             subject=StudySubject.objects.get(nd_number='cov-000111', study=get_test_study()),
@@ -89,7 +92,7 @@ class TestTnsCsvVisitReader(TestCase):
         AppointmentTypeLink.objects.create(appointment=old_appointment,
                                            appointment_type=self.visit_import_data.appointment_type)
 
-        TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        TnsCsvVisitImportReader(self.visit_import_data).load_data()
 
         self.assertEqual(1, Provenance.objects.filter(modified_table=Appointment._meta.db_table,
                                                       modified_table_id=old_appointment.id,
@@ -103,9 +106,9 @@ class TestTnsCsvVisitReader(TestCase):
                                                       previous_value__exact='').count())
 
     def test_data_provenance_for_create_appointment_load_data(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
 
-        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)[0]
+        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data()[0]
         appointment = visit.appointment_set.all()[0]
 
         self.assertEqual(0, Provenance.objects.filter(modified_table=Appointment._meta.db_table,
@@ -120,9 +123,9 @@ class TestTnsCsvVisitReader(TestCase):
                                                       previous_value__exact='').count())
 
     def test_data_provenance_for_create_visit_load_data(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
 
-        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)[0]
+        visit = TnsCsvVisitImportReader(self.visit_import_data).load_data()[0]
 
         self.assertEqual(0, Provenance.objects.filter(modified_table=Visit._meta.db_table,
                                                       modified_table_id=visit.id,
@@ -136,12 +139,12 @@ class TestTnsCsvVisitReader(TestCase):
                                                       previous_value__exact='').count())
 
     def test_load_data_with_existing_visit(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
         visit = Visit.objects.create(subject=StudySubject.objects.filter(nd_number='cov-000111')[0],
                                      datetime_end=timezone.now(),
                                      datetime_begin=timezone.now(),
                                      visit_number=1)
-        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data()
         visit = Visit.objects.filter(id=visits[0].id)[0]
         self.assertEqual("cov-000111", visit.subject.nd_number)
 
@@ -156,7 +159,7 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual(0, self.get_warnings_count())
 
     def test_load_data_with_existing_visit_and_appointment(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
         visit = Visit.objects.create(subject=StudySubject.objects.filter(nd_number='cov-000111')[0],
                                      datetime_end=timezone.now(),
                                      datetime_begin=timezone.now(),
@@ -167,7 +170,7 @@ class TestTnsCsvVisitReader(TestCase):
         AppointmentTypeLink.objects.create(appointment_id=appointment.id,
                                            appointment_type=AppointmentType.objects.filter(code="SAMPLE_2")[0])
 
-        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data()
         visit = Visit.objects.filter(id=visits[0].id)[0]
         self.assertEqual("cov-000111", visit.subject.nd_number)
 
@@ -184,9 +187,9 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual(0, self.get_warnings_count())
 
     def test_load_data_with_visit_and_no_previous_visits(self):
-        filename = get_resource_path('tns_vouchers_3_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_3_import.csv'
 
-        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data()
 
         subject_visits = Visit.objects.filter(subject=StudySubject.objects.filter(nd_number='cov-000111')[0])
 
@@ -206,9 +209,9 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual(2, self.get_warnings_count())
 
     def test_load_data_with_no_subject(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
         StudySubject.objects.filter(nd_number="cov-000111").delete()
-        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        visits = TnsCsvVisitImportReader(self.visit_import_data).load_data()
         self.assertEqual(3, len(visits))
         visit = Visit.objects.filter(id=visits[0].id)[0]
         self.assertEqual("cov-000111", visit.subject.nd_number)
@@ -217,11 +220,11 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual(0, self.get_warnings_count())
 
     def test_dont_add_links_for_existing_appointments(self):
-        filename = get_resource_path('tns_vouchers_import.csv')
-        TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        self.visit_import_data.filename = 'tns_vouchers_import.csv'
+        TnsCsvVisitImportReader(self.visit_import_data).load_data()
         links = AppointmentTypeLink.objects.all().count()
 
-        TnsCsvVisitImportReader(self.visit_import_data).load_data(filename)
+        TnsCsvVisitImportReader(self.visit_import_data).load_data()
         self.assertEqual(links, AppointmentTypeLink.objects.all().count())
 
         self.assertEqual(0, self.get_warnings_count())
-- 
GitLab