source: main/branches/3D/openPLM/plmapp/tests/csvimport.py @ 662

Revision 662, 11.2 KB checked in by pcosquer, 8 years ago (diff)

3D branch: merge changes from trunk (rev [661])

Line 
1
2import cStringIO, StringIO
3from collections import defaultdict
4
5from django.core import mail
6from django.contrib.auth.models import User
7from django.test import TransactionTestCase
8
9from celery.signals import task_prerun
10
11import openPLM.plmapp.models as models
12from openPLM.plmapp.models import GroupInfo, PLMObject, ParentChildLink
13from openPLM.plmapp.csvimport import PLMObjectsImporter, BOMImporter,\
14        CSVImportError, UsersImporter
15from openPLM.plmapp.base_views import get_obj
16from openPLM.plmapp.unicodecsv import UnicodeWriter
17from openPLM.plmapp.forms import CSVForm
18
19
class CSVImportTestCase(TransactionTestCase):
    """
    Tests of the CSV importers (:class:`PLMObjectsImporter`,
    :class:`BOMImporter` and :class:`UsersImporter`) and of the
    ``/import/csv/`` views.
    """

    # name of the celery task expected to run once after a successful
    # import of PLM objects
    _UPDATE_INDEXES_TASK = "openPLM.plmapp.tasks.update_indexes"

    def setUp(self):
        """
        Creates the fixtures shared by all tests:

            * ``self.cie``: user "company", contributor, owner and member
              of ``self.leading_group``
            * ``self.user``: user "user" (password "password"), contributor,
              owner and member of ``self.group`` ("grp")

        Then logs the test client in as "user" and records every celery
        task that is run in ``self.sent_tasks`` (task name -> list of tasks).
        """
        super(CSVImportTestCase, self).setUp()
        self.sent_tasks = defaultdict(list)

        # company user and its group
        self.cie = User.objects.create(username="company")
        cie_profile = self.cie.get_profile()
        cie_profile.is_contributor = True
        cie_profile.save()
        self.leading_group = GroupInfo.objects.create(name="leading_group",
                owner=self.cie, creator=self.cie)
        self.cie.groups.add(self.leading_group)

        # user who performs the imports
        self.user = User(username="user")
        self.user.email = "test@example.net"
        self.user.set_password("password")
        self.user.save()
        # fetch the profile once so that the modified instance is the one
        # which is saved (does not rely on get_profile() caching its result)
        user_profile = self.user.get_profile()
        user_profile.is_contributor = True
        user_profile.save()
        self.group = GroupInfo(name="grp", owner=self.user, creator=self.user,
                description="grp")
        self.group.save()
        self.user.groups.add(self.group)

        self.client.post("/login/", {'username' : 'user', 'password' : 'password'})
        task_prerun.connect(self.task_sent_handler)

    def task_sent_handler(self, sender=None, task_id=None, task=None, args=None,
                      kwargs=None, **kwds):
        """
        Receiver of celery's ``task_prerun`` signal: stores *task* in
        ``self.sent_tasks`` under its name. Other arguments are ignored.
        """
        self.sent_tasks[task.name].append(task)

    def tearDown(self):
        """
        Disconnects :meth:`task_sent_handler` and resets the ``inmemory_db``
        attribute of the haystack search backend.
        """
        try:
            super(CSVImportTestCase, self).tearDown()
        finally:
            # always clean up, otherwise the handler of this instance would
            # keep receiving tasks run by the following tests
            task_prerun.disconnect(self.task_sent_handler)
            from haystack import backend
            # NOTE(review): presumably drops the data indexed during the
            # test -- confirm against the configured haystack backend
            backend.SearchBackend.inmemory_db = None

    def get_valid_rows(self):
        """
        Returns the rows (header first) of a valid csv file that describes
        five PLM objects (one Part, two Documents and two SingleParts),
        all of them in ``self.group``.
        """
        group = self.group.name
        lifecycle = u'draft_official_deprecated'
        return [
            [u'Type', u'reference', u'revision', u'name', u'supplier',
             u'group', u'lifecycle'],
            [u'Part', u'p1', u'a', u'Part1', u'Moi', group, lifecycle],
            [u'Document', u'd1', u'2', u'Document1', u'', group, lifecycle],
            [u'Document', u'd2', u'7', u'Document 2', u'', group, lifecycle],
            [u'SinglePart', u'sp1', u's', u'SP1', u'Lui', group, lifecycle],
            [u'SinglePart', u'sp2', u's', u'SP2', u'Lui', group, lifecycle],
        ]

    def import_csv(self, Importer, rows):
        """
        Writes *rows* into an in-memory csv file and imports it with
        *Importer* (an importer class) on behalf of ``self.user``.

        The headers are the ones guessed by the preview of the importer.
        Returns what ``Importer.import_csv`` returns; exceptions raised by
        the importer are not caught.
        """
        csv_file = cStringIO.StringIO()
        UnicodeWriter(csv_file).writerows(rows)
        csv_file.seek(0)
        importer = Importer(csv_file, self.user)
        headers = importer.get_preview().guessed_headers
        return importer.import_csv(headers)

    def test_import_valid(self):
        """
        Tests an import of valid PLM objects: one object per data row is
        created, one mail per object is sent and the indexes are updated
        once.
        """
        csv_rows = self.get_valid_rows()
        objects = self.import_csv(PLMObjectsImporter, csv_rows)
        # the first row is the header
        self.assertEqual(len(csv_rows) - 1, len(objects))
        sp1 = get_obj("SinglePart", "sp1", "s", self.user)
        self.assertEqual("SP1", sp1.name)
        self.assertEqual(len(mail.outbox), len(objects))
        self.assertEqual(1, len(self.sent_tasks[self._UPDATE_INDEXES_TASK]))

    def test_import_csv_invalid_last_row(self):
        """
        Tests that an import with an invalid row does not modify
        the database.
        """
        csv_rows = self.get_valid_rows()
        # same number of columns as the header (empty supplier) so that
        # the unknown type is the only defect of this row
        csv_rows.append(["BadType", "bt", "1", "BT", "",
            self.group.name, u'draft_official_deprecated'])
        plmobjects = list(PLMObject.objects.all())
        self.assertRaises(CSVImportError, self.import_csv,
                PLMObjectsImporter, csv_rows)
        self.assertEqual(plmobjects, list(PLMObject.objects.all()))
        self.assertEqual(len(mail.outbox), 0)
        self.assertFalse(self.sent_tasks[self._UPDATE_INDEXES_TASK])

    def get_valid_bom(self):
        """
        Returns the rows (header first) of a valid bom that links objects
        described by :meth:`get_valid_rows`: p1 -> sp1 and sp1 -> sp2.
        """
        return [["parent-type", "parent-reference", "parent-revision",
                 "child-type", "child-reference", "child-revision",
                 "quantity", "order"],
                ["Part", "p1", "a", "SinglePart", "sp1", "s", "10", "15"],
                ["SinglePart", "sp1", "s", "SinglePart", "sp2", "s", "10.5", "16"],
                ]

    def _assert_first_child_link(self, parent_key, child_key, quantity, order):
        """
        Checks that the first child of the object identified by *parent_key*
        is the object identified by *child_key* and that their link has the
        given *quantity* and *order*.

        Keys are (type, reference, revision) tuples.
        """
        p_type, p_reference, p_revision = parent_key
        c_type, c_reference, c_revision = child_key
        parent = get_obj(p_type, p_reference, p_revision, self.user)
        child = get_obj(c_type, c_reference, c_revision, self.user)
        link = parent.get_children()[0].link
        self.assertEqual(link.parent.id, parent.id)
        self.assertEqual(link.child.id, child.id)
        self.assertEqual(link.quantity, quantity)
        self.assertEqual(link.order, order)

    def test_import_bom_valid(self):
        """
        Tests an import of a valid bom.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        objects = self.import_csv(BOMImporter, csv_rows)
        # objects should be [parent1, child1, ...]
        self.assertEqual((len(csv_rows) - 1) * 2, len(objects))

        # first row
        self._assert_first_child_link(("Part", "p1", "a"),
                ("SinglePart", "sp1", "s"), 10, 15)
        # second row
        self._assert_first_child_link(("SinglePart", "sp1", "s"),
                ("SinglePart", "sp2", "s"), 10.5, 16)

    def test_import_bom_invalid_order(self):
        """
        Tests an import of an invalid bom: invalid order.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        csv_rows[-1][-1] = "not an integer"
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        # no link at all must have been created, valid rows included
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_import_bom_invalid_parent(self):
        """
        Tests an import of an invalid bom: invalid parent.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        csv_rows[1][0] = "not an type"
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_import_bom_invalid_duplicated_row(self):
        """
        Tests an import of an invalid bom: a row is duplicated.
        """
        self.import_csv(PLMObjectsImporter, self.get_valid_rows())
        csv_rows = self.get_valid_bom()
        csv_rows.append(csv_rows[-1])
        self.assertRaises(CSVImportError, self.import_csv,
                          BOMImporter, csv_rows)
        self.assertEqual(0, ParentChildLink.objects.count())

    def test_view_init_get(self):
        """
        Tests the first page of the import: step 1, target "csv" and
        a :class:`CSVForm` in the context.
        """
        response = self.client.get("/import/csv/")
        self.assertEqual(200, response.status_code)
        self.assertEqual(1, response.context["step"])
        self.assertEqual("csv", response.context["target"])
        form = response.context["csv_form"]
        self.assertTrue(isinstance(form, CSVForm))

    def test_view_csv_all(self):
        """
        Complex test that simulates an upload of a csv file (complete process).
        """
        # upload a csv file
        # NOTE(review): StringIO.StringIO (not cStringIO) is presumably used
        # because a ``name`` attribute must be set on the file -- confirm
        csv_file = StringIO.StringIO()
        csv_file.name = "data.csv"
        UnicodeWriter(csv_file).writerows(self.get_valid_rows())
        csv_file.seek(0)
        try:
            response = self.client.post("/import/csv/", {"encoding":"utf_8",
                "filename":"data.csv", "file":csv_file}, follow=True)
        finally:
            # close the file even if the request fails
            csv_file.close()

        # load the second page
        url = response.redirect_chain[0][0]
        response2 = self.client.get(url)
        self.assertEqual(2, response2.context["step"])
        preview = response2.context["preview"]
        # all headers must have been guessed
        self.assertFalse(None in preview.guessed_headers)
        formset = response2.context["headers_formset"]

        # validate and import the file: post back the initial data of the
        # formset (management form + one form per column)
        data = {}
        for key, value in formset.management_form.initial.iteritems():
            data["form-" + key] = value or ""
        for i, d in enumerate(formset.initial):
            for key, value in d.iteritems():
                data["form-%d-%s" % (i, key)] = value
            data['form-%d-ORDER' % i] = str(i)
        response3 = self.client.post(url, data, follow=True)
        url_done = response3.redirect_chain[-1][0]
        self.assertEqual("http://testserver/import/done/", url_done)
        # check an item
        sp1 = get_obj("SinglePart", "sp1", "s", self.user)
        self.assertEqual("SP1", sp1.name)

    def get_users_rows(self):
        """
        Returns the rows (header first) of a valid csv file that describes
        five users (user_1 ... user_5), all of them in the group "grp".
        """
        rows = [['username', 'first_name', 'last_name', 'email', 'groups']]
        for i in range(1, 6):
            rows.append(['user_%d' % i, 'fn%d' % i, 'ln%d' % i,
                         'user_%d@example.net' % i, 'grp'])
        return rows

    def _assert_no_user_imported(self):
        """
        Checks that a failed import of users left no trace: no new user,
        no sponsor link delegated by ``self.user`` and no mail sent.
        """
        # 2 : company + user (created by setUp)
        self.assertEqual(2, User.objects.count())
        sponsor_links = self.user.delegationlink_delegator.filter(role="sponsor")
        self.assertFalse(bool(sponsor_links))
        self.assertEqual(len(mail.outbox), 0)

    def test_users_valid(self):
        """
        Tests an import of valid users: all users are created and
        ``self.user`` sponsors each of them.
        """
        csv_rows = self.get_users_rows()
        # the first row is the header
        expected = len(csv_rows) - 1
        objects = self.import_csv(UsersImporter, csv_rows)
        self.assertEqual(expected, len(objects))
        users = models.User.objects.filter(username__startswith="user_")
        self.assertEqual(expected, users.count())
        sponsor_links = self.user.delegationlink_delegator.filter(role="sponsor")
        self.assertEqual(expected, sponsor_links.count())
        self.assertEqual(set(users.values_list("id", flat=True)),
                set(sponsor_links.values_list("delegatee", flat=True)))

    def test_users_invalid_duplicated_users(self):
        """
        Tests an import of invalid users: a user is described twice.
        """
        csv_rows = self.get_users_rows()
        csv_rows.append(csv_rows[1])
        self.assertRaises(CSVImportError, self.import_csv,
                          UsersImporter, csv_rows)
        self._assert_no_user_imported()

    def test_users_invalid_missing_first_name(self):
        """
        Tests an import of invalid users: the first name of a user is empty.
        """
        csv_rows = self.get_users_rows()
        csv_rows[1][1] = ""
        self.assertRaises(CSVImportError, self.import_csv,
                          UsersImporter, csv_rows)
        self._assert_no_user_imported()
288
Note: See TracBrowser for help on using the repository browser.