From a74d7d6c008125b283f45632e7ad5c1268266576 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Mon, 25 Jan 2021 15:29:05 +0100
Subject: [PATCH] allow to import language for written communication

---
 .../web/importer/csv_subject_import_reader.py | 22 +++++++++++++++++--
 smash/web/tests/data/import_language.csv      |  4 ++++
 .../test_csv_subject_import_reader.py         | 12 ++++++++++
 3 files changed, 36 insertions(+), 2 deletions(-)
 create mode 100644 smash/web/tests/data/import_language.csv

diff --git a/smash/web/importer/csv_subject_import_reader.py b/smash/web/importer/csv_subject_import_reader.py
index f80262c1..2e73d303 100644
--- a/smash/web/importer/csv_subject_import_reader.py
+++ b/smash/web/importer/csv_subject_import_reader.py
@@ -5,7 +5,7 @@ from typing import List, Type, Tuple
 from django.db import models
 from django.db.models import Field
 
-from web.models import StudySubject, Subject, SubjectImportData
+from web.models import StudySubject, Subject, SubjectImportData, Language
 from .etl_common import EtlCommon
 from .subject_import_reader import SubjectImportReader
 
@@ -52,6 +52,9 @@ class CsvSubjectImportReader(SubjectImportReader):
         if field.get_internal_type() == "DateField":
             value = self.get_date(value)
 
+        if field.get_internal_type() == "ForeignKey":
+            value = self.get_value_for_foreign_field(field, value)
+
         if table == Subject:
             old_val = getattr(study_subject.subject, field.name)
             setattr(study_subject.subject, field.name, self.get_new_value(old_val, value))
@@ -61,6 +64,20 @@ class CsvSubjectImportReader(SubjectImportReader):
         else:
             logger.warning("Don't know how to handle column " + column_name + " with data " + value)
 
+    @staticmethod
+    def get_value_for_foreign_field(field, value):
+        if field.related_model == Language:
+            if value == "":
+                return None
+            else:
+                language = Language.objects.filter(name=value).first()
+                if language is None:
+                    language = Language.objects.create(name=value)
+                return language
+        else:
+            logger.warning("Don't know how to handle type " + str(field.related_model))
+            return None
+
     def get_table_and_field(self, column_name: str) -> Tuple[Type[models.Model], Field]:
         return self.mappings.get(column_name, (None, None))
 
@@ -68,7 +85,8 @@ class CsvSubjectImportReader(SubjectImportReader):
         for field in object_type._meta.get_fields():
             if field.get_internal_type() == "CharField" or \
                     field.get_internal_type() == "DateField" or \
-                    field.get_internal_type() == "TextField":
+                    field.get_internal_type() == "TextField" or \
+                    (field.get_internal_type() == "ForeignKey" and field.related_model in (Language,)):
                 found = False
                 for mapping in self.import_data.column_mappings.all():
                     if mapping.table_name == object_type._meta.db_table and field.name == mapping.column_name:
diff --git a/smash/web/tests/data/import_language.csv b/smash/web/tests/data/import_language.csv
new file mode 100644
index 00000000..c50d0355
--- /dev/null
+++ b/smash/web/tests/data/import_language.csv
@@ -0,0 +1,4 @@
+first_name,last_name,participant_id,default_written_communication_language
+Piotr,Gawron,Cov-000001,English
+Piotr,Gawron,Cov-000002,Polish
+Piotr,Gawron,Cov-000003,
\ No newline at end of file
diff --git a/smash/web/tests/importer/test_csv_subject_import_reader.py b/smash/web/tests/importer/test_csv_subject_import_reader.py
index c80ec3a7..792331b8 100644
--- a/smash/web/tests/importer/test_csv_subject_import_reader.py
+++ b/smash/web/tests/importer/test_csv_subject_import_reader.py
@@ -51,6 +51,18 @@ class TestCsvReader(TestCase):
         self.assertIsNone(study_subjects[1].subject.date_born)
         self.assertIsNone(study_subjects[2].subject.date_born)
 
+    def test_load_language(self):
+        self.subject_import_data.filename = get_resource_path('import_language.csv')
+        study_subjects = CsvSubjectImportReader(self.subject_import_data).load_data()
+        for study_subject in study_subjects:
+            study_subject.subject.save()
+            study_subject.save()
+
+        self.assertEqual(3, len(study_subjects))
+        self.assertIsNotNone(study_subjects[0].subject.default_written_communication_language)
+        self.assertIsNotNone(study_subjects[1].subject.default_written_communication_language)
+        self.assertIsNone(study_subjects[2].subject.default_written_communication_language)
+
     def test_load_data_for_tns(self):
         self.subject_import_data = SubjectImportData.objects.create(study=get_test_study(),
                                                                     date_format="%d/%m/%Y",
-- 
GitLab