提交 999a05b0 编写于 作者: S sneaxiy

polish code

test=develop
上级 d752177b
...@@ -71,7 +71,7 @@ class DataToLoDTensorConverter(object): ...@@ -71,7 +71,7 @@ class DataToLoDTensorConverter(object):
for each_data in data: for each_data in data:
self._feed_impl_(each_data, lod[1:], lod_level - 1) self._feed_impl_(each_data, lod[1:], lod_level - 1)
def _check_shape_(self, shape): def _check_shape(self, shape):
for s1, s2 in zip(self.shape, shape): for s1, s2 in zip(self.shape, shape):
if s1 != s2 and s1 >= 0 and s2 >= 0: if s1 != s2 and s1 >= 0 and s2 >= 0:
raise ValueError( raise ValueError(
...@@ -82,9 +82,14 @@ class DataToLoDTensorConverter(object): ...@@ -82,9 +82,14 @@ class DataToLoDTensorConverter(object):
arr = numpy.array(self.data, dtype=self.dtype) arr = numpy.array(self.data, dtype=self.dtype)
if self.shape: if self.shape:
if len(arr.shape) != len(self.shape): if len(arr.shape) != len(self.shape):
arr = arr.reshape(self.shape) try:
arr = arr.reshape(self.shape)
except ValueError:
raise ValueError(
"Reshape error. What is defined in data layer is {}, but receive {}"
.format(self.shape, arr.shape))
else: else:
self._check_shape_(arr.shape) self._check_shape(arr.shape)
t = core.LoDTensor() t = core.LoDTensor()
t.set(arr, self.place) t.set(arr, self.place)
if self.lod_level > 0: if self.lod_level > 0:
......
...@@ -30,6 +30,12 @@ class TestDataFeeder(unittest.TestCase): ...@@ -30,6 +30,12 @@ class TestDataFeeder(unittest.TestCase):
self.assertEqual(result['image'].recursive_sequence_lengths(), []) self.assertEqual(result['image'].recursive_sequence_lengths(), [])
self.assertEqual(result['label'].recursive_sequence_lengths(), []) self.assertEqual(result['label'].recursive_sequence_lengths(), [])
try:
result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])])
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
def test_lod_level_1_converter(self): def test_lod_level_1_converter(self):
# lod_level = 1 # lod_level = 1
# each sentence has a different number of words # each sentence has a different number of words
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册