From 53a62ea4677d0fd1542e9ceed7bd2f573e272c0e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 1 Apr 2022 17:29:23 +0800 Subject: [PATCH] [ControlFlow] Fix contrib API bug in while_loop (#41230) * [ControlFlow] Fix contrib API bug in while_loop * format code --- python/paddle/fluid/layers/control_flow.py | 16 +++++++- .../fluid/tests/unittests/test_while_op.py | 39 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 138e968a0b3..785a3e6eac1 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -974,6 +974,19 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs, :return: inner_inputs, inner_outputs """ + def is_ignore_vars(op, var_name): + # NOTE(dev): There are some persistable var created in some non-standard API + # such as "contrib.layers.shuffle_batch". It create a "Seed" used both in + # Input and Output. This var shall not be considered as a loop_var in + # control_flow. + IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]} + if op.type in IGNORE_VAR_NAMES: + var_names = IGNORE_VAR_NAMES[op.type] + for name in var_names: + if name in var_name: + return True + return False + # Step1: update inner_inputs and inner_outputs # NOTE: Here assumes that all variables are input or output of Ops, # but some variables are created without appendding a real op. @@ -982,7 +995,8 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs, assert isinstance(op, Operator) for iname in op.input_names: for in_var_name in op.input(iname): - if in_var_name not in inner_outputs: + if in_var_name not in inner_outputs and not is_ignore_vars( + op, in_var_name): inner_inputs.add(in_var_name) for oname in op.output_names: diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index d6d52b7d604..8af9a39634f 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -137,5 +137,44 @@ class BadInputTest(unittest.TestCase): self.assertRaises(TypeError, test_bad_x) +class TestIgnoreVarNameInWhile(unittest.TestCase): + def test_ignore_var(self): + def cond(i, ten, temp, y): + return i < ten + + def body_func(i, ten, batch_info, origin_seq): + print(batch_info) + batch_info = fluid.contrib.layers.shuffle_batch(batch_info) + print(batch_info) + i = i + 1 + return [i, ten, batch_info, origin_seq] + + x = fluid.layers.data(name='x', shape=[-1, 1, 4]) + y = fluid.layers.data(name='y', shape=[-1, 1, 1]) + temp = layers.concat(input=[x, y], axis=-1) + i = layers.fill_constant(shape=[1], value=0, dtype='int32') + num = layers.fill_constant(shape=[1], value=5, dtype='int32') + + i, ten, shuffle_temp, y = layers.while_loop(cond, body_func, + [i, num, temp, y]) + + output = shuffle_temp + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + + input_x = numpy.array([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]]) + input_x = input_x.reshape(3, 1, 4) + input_y = numpy.array([[10], [12], [33]]) + input_y = input_y.reshape(3, 1, 1) + + res, = exe.run(fluid.default_main_program(), + feed={'x': input_x, + 'y': input_y}, + fetch_list=[output]) + + self.assertListEqual(list(res.shape), [3, 1, 5]) + + if __name__ == '__main__': unittest.main() -- GitLab