From f2595bcb7666db74363320da88a8cd32f0df285f Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Thu, 26 Nov 2020 10:43:34 +0100
Subject: [PATCH] default location and country is selectanble on subject import

---
 .../web/importer/csv_subject_import_reader.py |  3 ++-
 smash/web/migrations/0179_visitimportdata.py  |  2 ++
 smash/web/models/etl/subject_import.py        | 15 ++++++++++++++-
 .../test_csv_subject_import_reader.py         | 19 +++++++++++++++++--
 4 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/smash/web/importer/csv_subject_import_reader.py b/smash/web/importer/csv_subject_import_reader.py
index 7acb8d5c..42d37ddf 100644
--- a/smash/web/importer/csv_subject_import_reader.py
+++ b/smash/web/importer/csv_subject_import_reader.py
@@ -27,8 +27,10 @@ class CsvSubjectImportReader(SubjectImportReader):
             headers = next(reader, None)
             for row in reader:
                 subject = Subject()
+                subject.country = self.import_data.country
                 study_subject = StudySubject()
                 study_subject.subject = subject
+                study_subject.default_location = self.import_data.location
                 study_subject.study = self.import_data.study
                 for header, value in zip(headers, row):
                     self.add_data(study_subject, header, value)
@@ -54,7 +56,6 @@ class CsvSubjectImportReader(SubjectImportReader):
             setattr(study_subject.subject, field.name, self.get_new_value(old_val, value))
         elif table == StudySubject:
             old_val = getattr(study_subject, field.name)
-            print(field.name + ": " + str(old_val) + " - " + str(value))
             setattr(study_subject, field.name, self.get_new_value(old_val, value))
         else:
             logger.warning("Don't know how to handle column " + column_name + " with data " + value)
diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py
index 6dab8df5..adaf9641 100644
--- a/smash/web/migrations/0179_visitimportdata.py
+++ b/smash/web/migrations/0179_visitimportdata.py
@@ -47,6 +47,8 @@ class Migration(migrations.Migration):
             name='SubjectImportData',
             fields=[
                 ('etldata_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='web.EtlData')),
+                ('country', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='web.Country', verbose_name='Default country')),
+                ('location', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='web.Location', verbose_name='Default location')),
             ],
             bases=('web.etldata',),
         ),
diff --git a/smash/web/models/etl/subject_import.py b/smash/web/models/etl/subject_import.py
index 58845869..066e30fa 100644
--- a/smash/web/models/etl/subject_import.py
+++ b/smash/web/models/etl/subject_import.py
@@ -1,10 +1,23 @@
 # coding=utf-8
 import logging
 
+from django.db import models
+
 from web.models.etl.etl import EtlData
 
 logger = logging.getLogger(__name__)
 
 
 class SubjectImportData(EtlData):
-    pass
+    location = models.ForeignKey("web.Location",
+                                 verbose_name='Default location',
+                                 blank=True,
+                                 null=True,
+                                 on_delete=models.SET_NULL
+                                 )
+    country = models.ForeignKey("web.Country",
+                                verbose_name='Default country',
+                                blank=True,
+                                null=True,
+                                on_delete=models.SET_NULL
+                                )
diff --git a/smash/web/tests/importer/test_csv_subject_import_reader.py b/smash/web/tests/importer/test_csv_subject_import_reader.py
index a304d335..c80ec3a7 100644
--- a/smash/web/tests/importer/test_csv_subject_import_reader.py
+++ b/smash/web/tests/importer/test_csv_subject_import_reader.py
@@ -5,8 +5,9 @@ import logging
 from django.test import TestCase
 
 from web.importer import CsvSubjectImportReader, MsgCounterHandler
-from web.models import SubjectImportData, EtlColumnMapping, StudySubject
-from web.tests.functions import get_resource_path, get_test_study, create_tns_column_mapping
+from web.models import SubjectImportData, EtlColumnMapping, StudySubject, Country
+from web.models.constants import COUNTRY_AFGHANISTAN_ID
+from web.tests.functions import get_resource_path, get_test_study, create_tns_column_mapping, create_location
 
 logger = logging.getLogger(__name__)
 
@@ -84,3 +85,17 @@ class TestCsvReader(TestCase):
             return self.warning_counter.level2count["WARNING"]
         else:
             return 0
+
+    def test_load_default_country(self):
+        self.subject_import_data.filename = get_resource_path('import.csv')
+        self.subject_import_data.country = Country.objects.get(pk=COUNTRY_AFGHANISTAN_ID)
+        study_subjects = CsvSubjectImportReader(self.subject_import_data).load_data()
+
+        self.assertEqual(study_subjects[0].subject.country, Country.objects.get(pk=COUNTRY_AFGHANISTAN_ID))
+
+    def test_load_default_location(self):
+        self.subject_import_data.filename = get_resource_path('import.csv')
+        self.subject_import_data.location = create_location()
+        study_subjects = CsvSubjectImportReader(self.subject_import_data).load_data()
+
+        self.assertEqual(study_subjects[0].default_location, self.subject_import_data.location)
-- 
GitLab