From e654f1e74fb71a8f82b71e0dbb764b7243632b6e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 19 Aug 2022 11:28:06 +0800 Subject: [PATCH] [ Dy2Static ]Modify while interface[python] to fit onnx (#45034) * Make sure that the output of whilep must exist in the input * insert assign in block(0) * add unittest. --- .../fluid/dygraph/dygraph_to_static/utils.py | 5 +++ python/paddle/fluid/layers/control_flow.py | 2 + .../fluid/tests/unittests/test_while_op.py | 39 +++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index ed7faf83cef..d3db7209c65 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -156,7 +156,12 @@ def create_undefined_variable(): from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM var = data_layer_not_check(unique_name.generate("undefined_var"), [1], "float64") + # the variable is created in block(0), we append assign in block(0) either. + helper = LayerHelper('create_undefined_variable', **locals()) + saved_block_ids = helper.main_program.current_block_idx + helper.main_program.current_block_idx = 0 assign(RETURN_NO_VALUE_MAGIC_NUM, var) + helper.main_program.current_block_idx = saved_block_ids return var diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index d7b85961247..4a2cab2e062 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1161,6 +1161,8 @@ class While(object): if inner_var: out_vars.append(inner_var) + x_name_list |= set(map(lambda x: x.name, out_vars)) + step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 8e35a57f242..04093fdceb3 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -190,5 +190,44 @@ class TestIgnoreVarNameInWhile(unittest.TestCase): self.assertListEqual(list(res.shape), [3, 1, 5]) +class TestOutputsMustExistsInputs(unittest.TestCase): + + def test_outputs_exists_inputs(self): + """ + We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. + """ + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + + def func(x): + s = paddle.zeros([1]) + i = paddle.ones([1]) + max_len = paddle.shape(x)[0] + + def cond(i, s, x): + return i < max_len + + def body(i, s, x): + iter = x[i] + s += iter + i += 1 + return i, s, x + + [i, s, x] = paddle.static.nn.while_loop(cond, body, [i, s, x]) + return s + + paddle.enable_static() + x = paddle.static.data(shape=[-1], name='x') + func(x) + for op in main_program.block(0).ops: + if op.type == "while": + for out_name in op.output("Out"): + self.assertTrue( + out_name in op.input("X"), + "In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition." + .format(out_name)) + + if __name__ == '__main__': unittest.main() -- GitLab