未验证 提交 e654f1e7 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ]Modify while interface[python] to fit onnx (#45034)

* Make sure that the output of whilep must exist in the input

* insert assign in block(0)

* add unittest.
上级 933868cd
......@@ -156,7 +156,12 @@ def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
"float64")
# the variable is created in block(0), we append assign in block(0) either.
helper = LayerHelper('create_undefined_variable', **locals())
saved_block_ids = helper.main_program.current_block_idx
helper.main_program.current_block_idx = 0
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
helper.main_program.current_block_idx = saved_block_ids
return var
......
......@@ -1161,6 +1161,8 @@ class While(object):
if inner_var:
out_vars.append(inner_var)
x_name_list |= set(map(lambda x: x.name, out_vars))
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
......
......@@ -190,5 +190,44 @@ class TestIgnoreVarNameInWhile(unittest.TestCase):
self.assertListEqual(list(res.shape), [3, 1, 5])
class TestOutputsMustExistsInputs(unittest.TestCase):
def test_outputs_exists_inputs(self):
"""
We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx.
"""
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
def func(x):
s = paddle.zeros([1])
i = paddle.ones([1])
max_len = paddle.shape(x)[0]
def cond(i, s, x):
return i < max_len
def body(i, s, x):
iter = x[i]
s += iter
i += 1
return i, s, x
[i, s, x] = paddle.static.nn.while_loop(cond, body, [i, s, x])
return s
paddle.enable_static()
x = paddle.static.data(shape=[-1], name='x')
func(x)
for op in main_program.block(0).ops:
if op.type == "while":
for out_name in op.output("Out"):
self.assertTrue(
out_name in op.input("X"),
"In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition."
.format(out_name))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册