diff --git a/smash/web/importer/csv_subject_import_reader.py b/smash/web/importer/csv_subject_import_reader.py index 7acb8d5cfd787158ed7540c8a231cf3b68e375c7..42d37ddf926d88fbdb015fd4c649fec40c7902d1 100644 --- a/smash/web/importer/csv_subject_import_reader.py +++ b/smash/web/importer/csv_subject_import_reader.py @@ -27,8 +27,10 @@ class CsvSubjectImportReader(SubjectImportReader): headers = next(reader, None) for row in reader: subject = Subject() + subject.country = self.import_data.country study_subject = StudySubject() study_subject.subject = subject + study_subject.default_location = self.import_data.location study_subject.study = self.import_data.study for header, value in zip(headers, row): self.add_data(study_subject, header, value) @@ -54,7 +56,6 @@ class CsvSubjectImportReader(SubjectImportReader): setattr(study_subject.subject, field.name, self.get_new_value(old_val, value)) elif table == StudySubject: old_val = getattr(study_subject, field.name) - print(field.name + ": " + str(old_val) + " - " + str(value)) setattr(study_subject, field.name, self.get_new_value(old_val, value)) else: logger.warning("Don't know how to handle column " + column_name + " with data " + value) diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py index 6dab8df5ad90d5b2c9bdc0e00a45526703494016..adaf96410d26989545e3cb8a4e4a62a72a4ff2bc 100644 --- a/smash/web/migrations/0179_visitimportdata.py +++ b/smash/web/migrations/0179_visitimportdata.py @@ -47,6 +47,8 @@ class Migration(migrations.Migration): name='SubjectImportData', fields=[ ('etldata_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='web.EtlData')), + ('country', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='web.Country', verbose_name='Default country')), + ('location', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='web.Location', verbose_name='Default location')), ], bases=('web.etldata',), ), diff --git a/smash/web/models/etl/subject_import.py b/smash/web/models/etl/subject_import.py index 5884586995e0b39eeb05c7b07caaa483ea6394ed..066e30fa38c4af22386c5b8195e282a5b6e30a40 100644 --- a/smash/web/models/etl/subject_import.py +++ b/smash/web/models/etl/subject_import.py @@ -1,10 +1,23 @@ # coding=utf-8 import logging +from django.db import models + from web.models.etl.etl import EtlData logger = logging.getLogger(__name__) class SubjectImportData(EtlData): - pass + location = models.ForeignKey("web.Location", + verbose_name='Default location', + blank=True, + null=True, + on_delete=models.SET_NULL + ) + country = models.ForeignKey("web.Country", + verbose_name='Default country', + blank=True, + null=True, + on_delete=models.SET_NULL + ) diff --git a/smash/web/tests/importer/test_csv_subject_import_reader.py b/smash/web/tests/importer/test_csv_subject_import_reader.py index a304d33548aa0cca8b3a1cd91295094b8ed56431..c80ec3a77cd86428c38928dda667844f3f63fb95 100644 --- a/smash/web/tests/importer/test_csv_subject_import_reader.py +++ b/smash/web/tests/importer/test_csv_subject_import_reader.py @@ -5,8 +5,9 @@ import logging from django.test import TestCase from web.importer import CsvSubjectImportReader, MsgCounterHandler -from web.models import SubjectImportData, EtlColumnMapping, StudySubject -from web.tests.functions import get_resource_path, get_test_study, create_tns_column_mapping +from web.models import SubjectImportData, EtlColumnMapping, StudySubject, Country +from web.models.constants import COUNTRY_AFGHANISTAN_ID +from web.tests.functions import get_resource_path, get_test_study, create_tns_column_mapping, create_location logger = logging.getLogger(__name__) @@ -84,3 +85,17 @@ class TestCsvReader(TestCase): return self.warning_counter.level2count["WARNING"] else: return 0 + + def test_load_default_country(self): + self.subject_import_data.filename = get_resource_path('import.csv') + self.subject_import_data.country = Country.objects.get(pk=COUNTRY_AFGHANISTAN_ID) + study_subjects = CsvSubjectImportReader(self.subject_import_data).load_data() + + self.assertEqual(study_subjects[0].subject.country, Country.objects.get(pk=COUNTRY_AFGHANISTAN_ID)) + + def test_load_default_location(self): + self.subject_import_data.filename = get_resource_path('import.csv') + self.subject_import_data.location = create_location() + study_subjects = CsvSubjectImportReader(self.subject_import_data).load_data() + + self.assertEqual(study_subjects[0].default_location, self.subject_import_data.location)