From e1a5fb8f653c0c948abcebb3e3e252edd724f05c Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sat, 3 Sep 2022 17:01:08 +0800 Subject: [PATCH] [ONNX]Remove cond_var from Input('X') in while_loop op (#45675) * [ONXX]Remove cond_var from Input(X) in while_loop op * fix unittest --- python/paddle/fluid/layers/control_flow.py | 3 +++ python/paddle/fluid/tests/unittests/test_while_op.py | 1 + 2 files changed, 4 insertions(+) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index a5a04e6582d..a9f2eaa40e2 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1162,6 +1162,9 @@ class While(object): out_vars.append(inner_var) x_name_list |= set(map(lambda x: x.name, out_vars)) + # NOTE(dev): cond_var has been contained in Input('Condition'), so + # we remove it from Input('X') + x_name_list -= {self.cond_var.name} step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 04093fdceb3..144e7a5c496 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -223,6 +223,7 @@ class TestOutputsMustExistsInputs(unittest.TestCase): for op in main_program.block(0).ops: if op.type == "while": for out_name in op.output("Out"): + if out_name in op.input("Condition"): continue self.assertTrue( out_name in op.input("X"), "In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition." -- GitLab