未验证 提交 f8f66ec5 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static | Controlflow ]While + Cond support for python container. (#45105)

* while support for python container.
It is convenient to convert more dynamic graph codes into static graphs.

* cond support python container
上级 3d514e48
......@@ -53,7 +53,7 @@ void DenseTensor::check_memory_size() const {
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is d%, memory's size is %d.",
"But received Tensor's dimension is %d, memory's size is %d.",
numel() * SizeOf(dtype()),
memory_size()));
}
......
......@@ -25,6 +25,7 @@ from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, lo
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
from paddle.fluid.layers.utils import copy_mutable_vars
def indexable(x, code=None):
......@@ -92,7 +93,10 @@ def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
def new_body_fn(*args):
""" wrap the body() and add return value for `while_loop`
the args may be differ from getter().
"""
mutable_loop_vars = args
setter(mutable_loop_vars)
body()
return getter()
......@@ -110,7 +114,7 @@ def _run_paddle_while(cond, body, getter, setter):
setter(loop_vars) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
setter(loop_vars) # change the non-local var to variable
setter(loop_vars) # change back to loop_vars
return loop_vars
......@@ -287,7 +291,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
init_args = get_args()
def new_true_fn():
set_args(init_args)
#init args may contain mutable python container like [var, 2], we copy then like in while_loop
set_args(copy_mutable_vars(init_args))
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
......@@ -295,7 +300,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
else: return ret
def new_false_fn():
set_args(init_args)
#init args may contain mutable python container like [var, 2], we copy then like in while_loop
set_args(copy_mutable_vars(init_args))
ret = false_fn()
if ret is None: return get_args()
else: return ret
......
......@@ -21,6 +21,7 @@ from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
from paddle.fluid.layers.utils import map_structure, is_sequence
__all__ = [
'create_bool_as_type',
......@@ -63,9 +64,12 @@ def to_static_variable(x):
if isinstance(x, six.integer_types):
return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar) or x is None:
""" for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
return create_undefined_variable()
if is_sequence(x):
return map_structure(to_static_variable, x)
return x
......
......@@ -1329,8 +1329,11 @@ def _deal_with_undefined_var(output_vars, loop_vars):
if isinstance(o_var,
(Variable, ) + support_ret_buildin_type) or o_var is None:
return create_undefined_variable()
if isinstance(o_var, (tuple, list)):
return [create_undefined_variable() for i in range(len(o_var))]
if is_sequence(o_var):
"""
Create a complex container class inside the body of while, including Python list and python Dict
"""
return map_structure(lambda x: create_undefined_variable(), o_var)
if len(output_vars) != len(loop_vars):
raise ValueError("The length of loop_vars should be the same.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册