diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index a5a04e6582dd9e5ddcf1be019e7af82f89930887..a9f2eaa40e2a54fbdeeb806b5ba986da1306afaa 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1162,6 +1162,9 @@ class While(object): out_vars.append(inner_var) x_name_list |= set(map(lambda x: x.name, out_vars)) + # NOTE(dev): cond_var has been contained in Input('Condition'), so + # we remove it from Input('X') + x_name_list -= {self.cond_var.name} step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 04093fdceb312f09f8bd3452fd094b36a212309f..144e7a5c496ed6d0b49597887ae05a637245f426 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -223,6 +223,7 @@ class TestOutputsMustExistsInputs(unittest.TestCase): for op in main_program.block(0).ops: if op.type == "while": for out_name in op.output("Out"): + if out_name in op.input("Condition"): continue self.assertTrue( out_name in op.input("X"), "In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition."