From 999a05b04bdb6eb62f8de8fe106e2df10388157c Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 9 Jan 2019 04:40:31 +0000 Subject: [PATCH] polish code test=develop --- python/paddle/fluid/data_feeder.py | 11 ++++++++--- python/paddle/fluid/tests/test_data_feeder.py | 6 ++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 1301525914c..7b70d19de5c 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -71,7 +71,7 @@ class DataToLoDTensorConverter(object): for each_data in data: self._feed_impl_(each_data, lod[1:], lod_level - 1) - def _check_shape_(self, shape): + def _check_shape(self, shape): for s1, s2 in zip(self.shape, shape): if s1 != s2 and s1 >= 0 and s2 >= 0: raise ValueError( @@ -82,9 +82,14 @@ class DataToLoDTensorConverter(object): arr = numpy.array(self.data, dtype=self.dtype) if self.shape: if len(arr.shape) != len(self.shape): - arr = arr.reshape(self.shape) + try: + arr = arr.reshape(self.shape) + except ValueError: + raise ValueError( + "Reshape error. What is defined in data layer is {}, but receive {}" + .format(self.shape, arr.shape)) else: - self._check_shape_(arr.shape) + self._check_shape(arr.shape) t = core.LoDTensor() t.set(arr, self.place) if self.lod_level > 0: diff --git a/python/paddle/fluid/tests/test_data_feeder.py b/python/paddle/fluid/tests/test_data_feeder.py index 01de564aa43..16a33fd3ab3 100644 --- a/python/paddle/fluid/tests/test_data_feeder.py +++ b/python/paddle/fluid/tests/test_data_feeder.py @@ -30,6 +30,12 @@ class TestDataFeeder(unittest.TestCase): self.assertEqual(result['image'].recursive_sequence_lengths(), []) self.assertEqual(result['label'].recursive_sequence_lengths(), []) + try: + result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])]) + self.assertTrue(False) + except ValueError: + self.assertTrue(True) + def test_lod_level_1_converter(self): # lod_level = 1 # each sentence has a different number of words -- GitLab