diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 812603ba14cfd944667d14fc3dae2c490005add3..69a7c019710db31ecb84a8ceb35a437f81f3b6f6 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -679,9 +679,11 @@ class While(object): raise TypeError("condition should be a variable") assert isinstance(cond, Variable) if cond.dtype != core.VarDesc.VarType.BOOL: - raise TypeError("condition should be a bool variable") + raise TypeError("condition should be a boolean variable") if reduce(lambda a, b: a * b, cond.shape, 1) != 1: - raise TypeError("condition should be a bool scalar") + raise TypeError( + "condition expected shape as [], but given shape as {0}.". + format(list(cond.shape))) self.cond_var = cond self.is_test = is_test diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 43fd9d425bffb1e0198f4e845da959570a964990..f19601d72835f7041d3d6434ffe9fcf09ad15065 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -96,6 +96,16 @@ class TestWhileOp(unittest.TestCase): fetch_list=[sum_result]) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) + def test_exceptions(self): + i = layers.zeros(shape=[2], dtype='int64') + array_len = layers.fill_constant(shape=[2], dtype='int64', value=1) + cond = layers.less_than(x=i, y=array_len) + with self.assertRaises(TypeError): + layers.While(cond=cond) + cond = layers.cast(cond, dtype='float64') + with self.assertRaises(TypeError): + layers.While(cond=cond) + if __name__ == '__main__': unittest.main()