未验证 提交 e1a5fb8f 编写于 作者: A Aurelius84 提交者: GitHub

[ONNX]Remove cond_var from Input('X') in while_loop op (#45675)

* [ONXX]Remove cond_var from Input(X) in while_loop op

* fix unittest
上级 eaea2bee
......@@ -1162,6 +1162,9 @@ class While(object):
out_vars.append(inner_var)
x_name_list |= set(map(lambda x: x.name, out_vars))
# NOTE(dev): cond_var has been contained in Input('Condition'), so
# we remove it from Input('X')
x_name_list -= {self.cond_var.name}
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
......
......@@ -223,6 +223,7 @@ class TestOutputsMustExistsInputs(unittest.TestCase):
for op in main_program.block(0).ops:
if op.type == "while":
for out_name in op.output("Out"):
if out_name in op.input("Condition"): continue
self.assertTrue(
out_name in op.input("X"),
"In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册