提交 999a05b0 编写于 作者: S sneaxiy

polish code

test=develop
上级 d752177b
......@@ -71,7 +71,7 @@ class DataToLoDTensorConverter(object):
for each_data in data:
self._feed_impl_(each_data, lod[1:], lod_level - 1)
def _check_shape_(self, shape):
def _check_shape(self, shape):
for s1, s2 in zip(self.shape, shape):
if s1 != s2 and s1 >= 0 and s2 >= 0:
raise ValueError(
......@@ -82,9 +82,14 @@ class DataToLoDTensorConverter(object):
arr = numpy.array(self.data, dtype=self.dtype)
if self.shape:
if len(arr.shape) != len(self.shape):
try:
arr = arr.reshape(self.shape)
except ValueError:
raise ValueError(
"Reshape error. What is defined in data layer is {}, but receive {}"
.format(self.shape, arr.shape))
else:
self._check_shape_(arr.shape)
self._check_shape(arr.shape)
t = core.LoDTensor()
t.set(arr, self.place)
if self.lod_level > 0:
......
......@@ -30,6 +30,12 @@ class TestDataFeeder(unittest.TestCase):
self.assertEqual(result['image'].recursive_sequence_lengths(), [])
self.assertEqual(result['label'].recursive_sequence_lengths(), [])
try:
result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])])
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
def test_lod_level_1_converter(self):
# lod_level = 1
# each sentence has a different number of words
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册