diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 57aefb7df973c16cae4a0fcb1e57743de6a96e2e..802d3e1cc384a42cf883ac57dc8f92bab7f09f48 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1031,9 +1031,12 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): output_vars = body(*loop_vars) if not isinstance(output_vars, (list, tuple)): output_vars = [output_vars] - if len(output_vars) != len(loop_vars): + try: + assert_same_structure(output_vars, loop_vars, check_types=False) + except ValueError as e: raise ValueError("body in while_loop should return the same arity " - "(length and structure) and types as loop_vars") + "(length and structure) as loop_vars: {0}".format( + e)) now_cond = cond(*output_vars) map_structure(assign, output_vars, loop_vars) assign(now_cond, pre_cond) diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index 47fb726c6dae3b8ce5298ba282eae28533af3606..b9c3d9dbf3853cbb6825007cd8484b88f635a75c 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -384,6 +384,23 @@ class TestApiWhileLoop_Error(unittest.TestCase): def body_returns_error_type(i, ten): return layers.increment(i) + def cond_returns_with_mutable_dict(i, test_dict): + return i > 0 + + def body_returns_with_mutable_dict(i, test_dict): + test_dict['new_key'] = layers.fill_constant( + shape=[1], dtype='int64', value=1) + return layers.increment(i), test_dict + + def cond_returns_with_mutable_list(i, test_list): + return i > 0 + + def body_returns_with_mutable_list(i, test_list): + test_list.append( + layers.fill_constant( + shape=[1], dtype='int64', value=1)) + return layers.increment(i), test_list + main_program = Program() startup_program = Program() with program_guard(main_program, startup_program): @@ -451,6 +468,31 @@ class TestApiWhileLoop_Error(unittest.TestCase): self.assertRaises(ValueError, value_error_body_returns_error_type) + # The length of `output_vars` with mutable value should keep same with `loop_vars` + def value_error_body_returns_with_mutable_dict(): + test_dict = { + "int_constant": layers.fill_constant( + shape=[2, 2], dtype='int64', value=1) + } + out = layers.while_loop(cond_returns_with_mutable_dict, + body_returns_with_mutable_dict, + [data, test_dict]) + + self.assertRaises(ValueError, + value_error_body_returns_with_mutable_dict) + + def value_error_body_returns_with_mutable_list(): + test_list = [ + layers.fill_constant( + shape=[2, 2], dtype='int64', value=1) + ] + out = layers.while_loop(cond_returns_with_mutable_list, + body_returns_with_mutable_list, + [data, test_list]) + + self.assertRaises(ValueError, + value_error_body_returns_with_mutable_list) + if __name__ == '__main__': unittest.main()