source: main/trunk/openPLM/plmapp/tests/csvimport.py @ 475

Revision 475, 9.3 KB checked in by pcosquer, 8 years ago (diff)

tests: csvimport: rebuild the search index and disconnect the task_prerun signal

Line 
1
2import os.path
3import shutil
4import cStringIO, StringIO
5from collections import defaultdict
6
7from django.conf import settings
8from django.core import mail
9from django.contrib.auth.models import User
10from django.test import TransactionTestCase
11from django.core.management import call_command
12
13from celery.signals import task_prerun
14
15from openPLM.plmapp.models import GroupInfo, PLMObject, ParentChildLink
16from openPLM.plmapp.csvimport import PLMObjectsImporter, BOMImporter,\
17        CSVImportError
18from openPLM.plmapp.base_views import get_obj
19from openPLM.plmapp.unicodecsv import UnicodeWriter
20from openPLM.plmapp.forms import CSVForm
21
22
class CSVImportTestCase(TransactionTestCase):
    """
    Tests of the CSV importers (:class:`PLMObjectsImporter`,
    :class:`BOMImporter`) and of the ``/import/csv/`` views.

    :meth:`setUp` creates a contributor ``user`` (member of the group
    ``grp``) and posts its credentials to ``/login/`` with the test client.
    Tasks reported by the celery ``task_prerun`` signal are recorded in
    ``self.sent_tasks`` (task name -> list of tasks).
    """

    def setUp(self):
        super(CSVImportTestCase, self).setUp()
        # task name -> tasks seen by task_sent_handler during the test
        self.sent_tasks = defaultdict(list)

        self.cie = User.objects.create(username="company")
        cie_profile = self.cie.get_profile()
        cie_profile.is_contributor = True
        cie_profile.save()
        self.leading_group = GroupInfo.objects.create(name="leading_group",
                owner=self.cie, creator=self.cie)
        self.cie.groups.add(self.leading_group)

        self.user = User(username="user")
        self.user.email = "test@example.net"
        self.user.set_password("password")
        self.user.save()
        # Fetch the profile once so the flag is set on the very object that
        # is saved, without relying on get_profile() caching its result.
        user_profile = self.user.get_profile()
        user_profile.is_contributor = True
        user_profile.save()
        self.group = GroupInfo(name="grp", owner=self.user, creator=self.user,
                description="grp")
        self.group.save()
        self.user.groups.add(self.group)

        self.client.post("/login/", {'username' : 'user', 'password' : 'password'})
        task_prerun.connect(self.task_sent_handler)
        call_command("rebuild_index", interactive=False, verbosity=0)

    def task_sent_handler(self, sender=None, task_id=None, task=None,
                          args=None, kwargs=None, **kwds):
        """Receiver of ``task_prerun``: records *task* under ``task.name``."""
        self.sent_tasks[task.name].append(task)

    def tearDown(self):
        # Disconnect first so the handler bound to this test cannot stay
        # connected even if the cleanup below raises.
        task_prerun.disconnect(self.task_sent_handler)
        try:
            # Remove the on-disk Xapian index so the next test starts from
            # the index rebuilt by its own setUp.
            if os.path.exists(settings.HAYSTACK_XAPIAN_PATH):
                shutil.rmtree(settings.HAYSTACK_XAPIAN_PATH)
        finally:
            super(CSVImportTestCase, self).tearDown()

    def get_valid_rows(self):
        """
        Returns a header row followed by five valid rows (one Part, two
        Documents and two SingleParts), all belonging to ``self.group``.
        """
        return [[u'Type',
              u'reference',
              u'revision',
              u'name',
              u'supplier',
              u'group',
              u'lifecycle'],
             [u'Part',
              u'p1',
              u'a',
              u'Part1',
              u'Moi',
              self.group.name,
              u'draft_official_deprecated'],
             [u'Document',
              u'd1',
              u'2',
              u'Document1',
              u'',
              self.group.name,
              u'draft_official_deprecated'],
             [u'Document',
              u'd2',
              u'7',
              u'Document 2',
              u'',
              self.group.name,
              u'draft_official_deprecated'],
             [u'SinglePart',
              u'sp1',
              u's',
              u'SP1',
              u'Lui',
              self.group.name,
              u'draft_official_deprecated'],
             [u'SinglePart',
              u'sp2',
              u's',
              u'SP2',
              u'Lui',
              self.group.name,
              u'draft_official_deprecated'],
             ]

    def import_csv(self, Importer, rows):
        """
        Writes *rows* to an in-memory CSV file, imports it as ``self.user``
        with *Importer* using the headers guessed by its preview, and returns
        what ``Importer.import_csv`` returns.
        """
        csv_file = cStringIO.StringIO()
        UnicodeWriter(csv_file).writerows(rows)
        csv_file.seek(0)
        importer = Importer(csv_file, self.user)
        headers = importer.get_preview().guessed_headers
        return importer.import_csv(headers)

    def test_import_valid(self):
        """
        Tests that a valid file creates one object per data row, sends one
        mail per created object and triggers a single update_indexes task.
        """
        csv_rows = self.get_valid_rows()
        objects = self.import_csv(PLMObjectsImporter, csv_rows)
        self.assertEqual(len(csv_rows) - 1, len(objects))
        sp1 = get_obj("SinglePart", "sp1", "s", self.user)
        self.assertEqual("SP1", sp1.name)
        self.assertEqual(len(mail.outbox), len(objects))
        self.assertEqual(1, len(self.sent_tasks["openPLM.plmapp.tasks.update_indexes"]))

    def test_import_csv_invalid_last_row(self):
        """
        Tests that an import with an invalid row does not modify
        the database.
        """
        csv_rows = self.get_valid_rows()
        # The row has every column of the header (including an empty
        # supplier) so that only its unknown type makes it invalid.
        csv_rows.append(["BadType", "bt", "1", "BT", u"",
            self.group.name, u'draft_official_deprecated'])
        plmobjects = list(PLMObject.objects.all())
        self.assertRaises(CSVImportError, self.import_csv,
                PLMObjectsImporter, csv_rows)
        self.assertEqual(plmobjects, list(PLMObject.objects.all()))
        self.assertEqual(len(mail.outbox), 0)
        self.assertFalse(self.sent_tasks["openPLM.plmapp.tasks.update_indexes"])

    def get_valid_bom(self):
        """
        Returns a BOM file: a header row and two links, p1/a -> sp1/s and
        sp1/s -> sp2/s. The objects come from :meth:`get_valid_rows`.
        """
        return [["parent-type", "parent-reference", "parent-revision",
                 "child-type", "child-reference", "child-revision",
                 "quantity", "order"],
                ["Part", "p1", "a", "SinglePart", "sp1", "s", "10", "15"],
                ["SinglePart", "sp1", "s", "SinglePart", "sp2", "s", "10.5", "16"],
                ]

    def test_import_bom_valid(self):
        """
        Tests an import of a valid bom.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        objects = self.import_csv(BOMImporter, csv_rows)
        # objects should be [parent1, child1, ...]
        self.assertEqual((len(csv_rows) - 1) * 2, len(objects))

        # first row
        parent = get_obj("Part", "p1", "a", self.user)
        child = get_obj("SinglePart", "sp1", "s", self.user)
        c = parent.get_children()[0]
        self.assertEqual(c.link.parent.id, parent.id)
        self.assertEqual(c.link.child.id, child.id)
        self.assertEqual(c.link.quantity, 10)
        self.assertEqual(c.link.order, 15)

        # second row
        parent = get_obj("SinglePart", "sp1", "s", self.user)
        child = get_obj("SinglePart", "sp2", "s", self.user)
        c = parent.get_children()[0]
        self.assertEqual(c.link.parent.id, parent.id)
        self.assertEqual(c.link.child.id, child.id)
        self.assertEqual(c.link.quantity, 10.5)
        self.assertEqual(c.link.order, 16)

    def test_import_bom_invalid_order(self):
        """
        Tests an import of an invalid bom: invalid order.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        csv_rows[-1][-1] = "not an integer"
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_import_bom_invalid_parent(self):
        """
        Tests an import of an invalid bom: invalid parent.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        csv_rows[1][0] = "not an type"
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_import_bom_invalid_duplicated_row(self):
        """
        Tests an import of an invalid bom: a row is duplicated.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        # append a copy rather than the same list object
        csv_rows.append(list(csv_rows[-1]))
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_view_init_get(self):
        """
        Tests that GET /import/csv/ renders the first step with a CSVForm.
        """
        response = self.client.get("/import/csv/")
        self.assertEqual(200, response.status_code)
        self.assertEqual(1, response.context["step"])
        self.assertEqual("csv", response.context["target"])
        form = response.context["csv_form"]
        self.assertTrue(isinstance(form, CSVForm))

    def test_view_csv_all(self):
        """
        Complex test that simulates an upload of a csv file (complete process).
        """
        # upload a csv file
        csv_file = StringIO.StringIO()
        csv_file.name = "data.csv"
        UnicodeWriter(csv_file).writerows(self.get_valid_rows())
        csv_file.seek(0)
        response = self.client.post("/import/csv/", {"encoding":"utf_8",
            "filename":"data.csv", "file":csv_file}, follow=True)
        csv_file.close()

        # load the second page (the URL the upload redirected to)
        url = response.redirect_chain[0][0]
        response2 = self.client.get(url)
        self.assertEqual(2, response2.context["step"])
        preview = response2.context["preview"]
        self.assertFalse(None in preview.guessed_headers)
        formset = response2.context["headers_formset"]

        # validate and import the file: resubmit the formset's management
        # form and initial values, keeping the columns in their original order
        data = {}
        for key, value in formset.management_form.initial.iteritems():
            data["form-" + key] = value or ""
        for i, d in enumerate(formset.initial):
            for key, value in d.iteritems():
                data["form-%d-%s" % (i, key)] = value
            data['form-%d-ORDER' % i] = str(i)
        response3 = self.client.post(url, data, follow=True)
        url_done = response3.redirect_chain[-1][0]
        self.assertEqual("http://testserver/import/done/", url_done)
        # check an item
        sp1 = get_obj("SinglePart", "sp1", "s", self.user)
        self.assertEqual("SP1", sp1.name)
250
251
252
Note: See TracBrowser for help on using the repository browser.