From e606b1754e21b2960bb6cec266594ca4c13ec4f2 Mon Sep 17 00:00:00 2001 From: wopeizl Date: Mon, 23 Sep 2019 15:57:19 +0800 Subject: [PATCH] =?UTF-8?q?optimize=20the=20error=20information=20when=20t?= =?UTF-8?q?he=20input=20for=20while=20op=20has=20a=20wron=E2=80=A6=20(#198?= =?UTF-8?q?72)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * optimize the error information when the input for while op has a wrong shape test=develop --- python/paddle/fluid/layers/control_flow.py | 6 ++++-- python/paddle/fluid/tests/unittests/test_while_op.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 812603ba14c..69a7c019710 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -679,9 +679,11 @@ class While(object): raise TypeError("condition should be a variable") assert isinstance(cond, Variable) if cond.dtype != core.VarDesc.VarType.BOOL: - raise TypeError("condition should be a bool variable") + raise TypeError("condition should be a boolean variable") if reduce(lambda a, b: a * b, cond.shape, 1) != 1: - raise TypeError("condition should be a bool scalar") + raise TypeError( + "condition expected shape as [], but given shape as {0}.". + format(list(cond.shape))) self.cond_var = cond self.is_test = is_test diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 43fd9d425bf..f19601d7283 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -96,6 +96,16 @@ class TestWhileOp(unittest.TestCase): fetch_list=[sum_result]) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) + def test_exceptions(self): + i = layers.zeros(shape=[2], dtype='int64') + array_len = layers.fill_constant(shape=[2], dtype='int64', value=1) + cond = layers.less_than(x=i, y=array_len) + with self.assertRaises(TypeError): + layers.While(cond=cond) + cond = layers.cast(cond, dtype='float64') + with self.assertRaises(TypeError): + layers.While(cond=cond) + if __name__ == '__main__': unittest.main() -- GitLab