未验证 提交 e606b175 编写于 作者: W wopeizl 提交者: GitHub

optimize the error information when the input for while op has a wron… (#19872)

* optimize the error information when the input for while op has a wrong shape test=develop
上级 d31c92a2
...@@ -679,9 +679,11 @@ class While(object): ...@@ -679,9 +679,11 @@ class While(object):
raise TypeError("condition should be a variable") raise TypeError("condition should be a variable")
assert isinstance(cond, Variable) assert isinstance(cond, Variable)
if cond.dtype != core.VarDesc.VarType.BOOL: if cond.dtype != core.VarDesc.VarType.BOOL:
raise TypeError("condition should be a bool variable") raise TypeError("condition should be a boolean variable")
if reduce(lambda a, b: a * b, cond.shape, 1) != 1: if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar") raise TypeError(
"condition expected shape as [], but given shape as {0}.".
format(list(cond.shape)))
self.cond_var = cond self.cond_var = cond
self.is_test = is_test self.is_test = is_test
......
...@@ -96,6 +96,16 @@ class TestWhileOp(unittest.TestCase): ...@@ -96,6 +96,16 @@ class TestWhileOp(unittest.TestCase):
fetch_list=[sum_result]) fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
def test_exceptions(self):
i = layers.zeros(shape=[2], dtype='int64')
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = layers.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册