From cc8ca8cea9506c58d5c2c99be3040452d2be882a Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 25 Mar 2020 17:22:25 +0800 Subject: [PATCH] Polish error Info in while_loop (#23183) * Polish error Info in while_loop test=develop --- python/paddle/fluid/layers/control_flow.py | 7 +++- .../tests/unittests/test_while_loop_op.py | 42 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 57aefb7df97..802d3e1cc38 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1031,9 +1031,12 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): output_vars = body(*loop_vars) if not isinstance(output_vars, (list, tuple)): output_vars = [output_vars] - if len(output_vars) != len(loop_vars): + try: + assert_same_structure(output_vars, loop_vars, check_types=False) + except ValueError as e: raise ValueError("body in while_loop should return the same arity " - "(length and structure) and types as loop_vars") + "(length and structure) as loop_vars: {0}".format( + e)) now_cond = cond(*output_vars) map_structure(assign, output_vars, loop_vars) assign(now_cond, pre_cond) diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index 47fb726c6da..b9c3d9dbf38 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -384,6 +384,23 @@ class TestApiWhileLoop_Error(unittest.TestCase): def body_returns_error_type(i, ten): return layers.increment(i) + def cond_returns_with_mutable_dict(i, test_dict): + return i > 0 + + def body_returns_with_mutable_dict(i, test_dict): + test_dict['new_key'] = layers.fill_constant( + shape=[1], dtype='int64', value=1) + return layers.increment(i), test_dict + + def cond_returns_with_mutable_list(i, test_list): + return i > 0 + + def body_returns_with_mutable_list(i, test_list): + test_list.append( + layers.fill_constant( + shape=[1], dtype='int64', value=1)) + return layers.increment(i), test_list + main_program = Program() startup_program = Program() with program_guard(main_program, startup_program): @@ -451,6 +468,31 @@ class TestApiWhileLoop_Error(unittest.TestCase): self.assertRaises(ValueError, value_error_body_returns_error_type) + # The length of `output_vars` with mutable value should keep same with `loop_vars` + def value_error_body_returns_with_mutable_dict(): + test_dict = { + "int_constant": layers.fill_constant( + shape=[2, 2], dtype='int64', value=1) + } + out = layers.while_loop(cond_returns_with_mutable_dict, + body_returns_with_mutable_dict, + [data, test_dict]) + + self.assertRaises(ValueError, + value_error_body_returns_with_mutable_dict) + + def value_error_body_returns_with_mutable_list(): + test_list = [ + layers.fill_constant( + shape=[2, 2], dtype='int64', value=1) + ] + out = layers.while_loop(cond_returns_with_mutable_list, + body_returns_with_mutable_list, + [data, test_list]) + + self.assertRaises(ValueError, + value_error_body_returns_with_mutable_list) + if __name__ == '__main__': unittest.main() -- GitLab