Unverified · Commit f8f66ec5, authored by xiongkun, committed by GitHub

[Dy2Static | Controlflow] While + Cond support for Python containers. (#45105)

* while: support Python containers as loop state; this makes it more convenient to convert dynamic-graph code into static graphs.

* cond: support Python containers.
Parent: 3d514e48
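A minimal, hypothetical usage sketch (not part of the diff) of what this change aims to enable: a plain Python list carried as state through a converted while loop and an if/else. The function name, shapes, and expected output below are illustrative assumptions, not taken from the commit.

    import paddle

    @paddle.jit.to_static
    def accumulate(x):
        state = [x, paddle.zeros([1])]        # mutable Python container: [Tensor, Tensor]
        i = paddle.full(shape=[1], dtype='int64', fill_value=0)
        while i < 3:                          # rewritten by Dy2Static into control_flow.while_loop
            state[0] = state[0] + 1.0         # the container is mutated inside the loop body
            state[1] = state[1] + 1.0
            i = i + 1
        if paddle.sum(state[0]) > 0:          # rewritten by Dy2Static into control_flow.cond
            return state[0]
        else:
            return state[0] - state[1]

    print(accumulate(paddle.zeros([2])))      # expected: a Tensor holding [3., 3.]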
@@ -53,7 +53,7 @@ void DenseTensor::check_memory_size() const {
           "Tensor's dimension is out of bound."
           "Tensor's dimension must be equal or less than the size of its "
           "memory."
-          "But received Tensor's dimension is d%, memory's size is %d.",
+          "But received Tensor's dimension is %d, memory's size is %d.",
           numel() * SizeOf(dtype()),
           memory_size()));
 }
......
@@ -25,6 +25,7 @@ from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, lo
 from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
 from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
 from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
+from paddle.fluid.layers.utils import copy_mutable_vars
 def indexable(x, code=None):
@@ -92,7 +93,10 @@ def _run_paddle_while(cond, body, getter, setter):
     # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
     def new_body_fn(*args):
-        """ wrap the body() and add return value for `while_loop` """
+        """ wrap the body() and add a return value for `while_loop`;
+            the args may differ from getter().
+        """
+        mutable_loop_vars = args
+        setter(mutable_loop_vars)
         body()
         return getter()
@@ -110,7 +114,7 @@ def _run_paddle_while(cond, body, getter, setter):
     setter(loop_vars)  # change the non-local var to variable
     # variable maybe modified to inner var. change it into
     loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
-    setter(loop_vars)  # change the non-local var to variable
+    setter(loop_vars)  # change back to loop_vars
     return loop_vars
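Why the extra setter(mutable_loop_vars) call: control_flow.while_loop hands the body fresh loop variables as `args` on every iteration, and these are generally not the same objects that the non-local variables read by body() currently point to, so the wrapper has to write `args` back through setter() before running body(). Below is a simplified, framework-free sketch of this getter/setter pattern; the names and the plain-Python while loop are illustrative stand-ins, not Paddle code.

    def make_loop(x):
        i, acc = 0, x            # non-local state captured by the closures below

        def getter():
            return [i, acc]

        def setter(values):
            nonlocal i, acc
            i, acc = values

        def body():              # reads/writes the non-locals, takes no arguments
            nonlocal i, acc
            acc = acc + [i]      # acc is a mutable Python container (a list)
            i = i + 1

        def new_body_fn(*args):  # signature a while_loop-style API expects
            setter(list(args))   # args may differ from what getter() returns
            body()
            return getter()

        loop_vars = getter()
        while loop_vars[0] < 3:  # stand-in for control_flow.while_loop
            loop_vars = new_body_fn(*loop_vars)
        setter(loop_vars)        # change back to loop_vars
        return acc

    print(make_loop([]))         # [0, 1, 2]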
@@ -287,7 +291,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
     init_args = get_args()
     def new_true_fn():
-        set_args(init_args)
+        # init_args may contain mutable python containers like [var, 2]; copy them as in while_loop
+        set_args(copy_mutable_vars(init_args))
         ret = true_fn()
         # IfExpr will return a non-None return value, so we just return ret.
         # We assume normal return has no return value.
@@ -295,7 +300,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
         else: return ret
     def new_false_fn():
-        set_args(init_args)
+        # init_args may contain mutable python containers like [var, 2]; copy them as in while_loop
+        set_args(copy_mutable_vars(init_args))
         ret = false_fn()
         if ret is None: return get_args()
         else: return ret
......
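Why copy_mutable_vars is needed in both branches: when building the static graph, new_true_fn and new_false_fn are both traced, and they start from the same init_args. If init_args contains a mutable Python container such as [var, 2], an in-place mutation made while tracing the true branch would leak into the starting state of the false branch. The snippet below is a rough, framework-free illustration of the hazard; copy.deepcopy stands in for the copy step only, while the actual copy_mutable_vars helper presumably copies container structure without duplicating the Variables inside.

    import copy

    def true_branch(args):
        args[0].append("true")          # mutates the inner list in place
        return args

    def false_branch(args):
        args[0].append("false")
        return args

    # Without copying, tracing the true branch first pollutes what the false branch sees:
    init_args = [["x", 2]]              # stands in for [var, 2]
    true_branch(init_args)
    false_branch(init_args)
    print(init_args)                    # [['x', 2, 'true', 'false']] -- branches interfere

    # Copying the mutable containers keeps each branch's starting state independent:
    init_args = [["x", 2]]
    true_branch(copy.deepcopy(init_args))
    false_branch(copy.deepcopy(init_args))
    print(init_args)                    # [['x', 2]] -- unchanged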
@@ -21,6 +21,7 @@ from paddle.utils import gast
 from paddle.fluid import unique_name
 from paddle.fluid.framework import Variable
 from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
+from paddle.fluid.layers.utils import map_structure, is_sequence
 __all__ = [
     'create_bool_as_type',
@@ -63,9 +64,12 @@ def to_static_variable(x):
     if isinstance(x, six.integer_types):
         return paddle.full(shape=[1], dtype='int64', fill_value=x)
     if isinstance(x, UndefinedVar) or x is None:
-        """ for early return case, we need a variable to represent None, current we use data_layer_not_check. """
+        """
+        for the early-return case we need a variable to represent None; currently we use data_layer_not_check.
+        """
         return create_undefined_variable()
+    if is_sequence(x):
+        return map_structure(to_static_variable, x)
     return x
......
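With is_sequence/map_structure, to_static_variable now recurses into nested containers and converts each Python scalar leaf the same way the scalar branches above do, while Tensors and Variables pass through unchanged. A hypothetical check of that behaviour (the import path is simply the module patched in this hunk; running it assumes a Paddle build that includes this change):

    import paddle
    from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable

    nested = [3, {"flag": True, "t": paddle.ones([1])}]
    converted = to_static_variable(nested)
    print(type(converted[0]))                   # the int 3 becomes a Tensor/Variable
    print(type(converted[1]["flag"]))           # the bool becomes a Tensor/Variable
    print(converted[1]["t"] is nested[1]["t"])  # True: existing Tensors pass through unchanged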
@@ -1329,8 +1329,11 @@ def _deal_with_undefined_var(output_vars, loop_vars):
         if isinstance(o_var,
                       (Variable, ) + support_ret_buildin_type) or o_var is None:
             return create_undefined_variable()
-        if isinstance(o_var, (tuple, list)):
-            return [create_undefined_variable() for i in range(len(o_var))]
+        if is_sequence(o_var):
+            """
+            Handle complex containers created inside the body of while, including Python lists and Python dicts.
+            """
+            return map_structure(lambda x: create_undefined_variable(), o_var)
     if len(output_vars) != len(loop_vars):
         raise ValueError("The length of loop_vars should be the same.")
......
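The map_structure change above matters when a value that only comes into existence inside the while body is a nested container: the matching placeholder for the loop inputs must mirror that nesting rather than being a flat list. A rough sketch of the structural behaviour, with the string "undef" standing in for the Variables that create_undefined_variable() would produce:

    from paddle.fluid.layers.utils import map_structure

    o_var = {"hist": ["t1", "t2"], "count": "t3"}   # shape of a value built inside the loop body
    placeholder = map_structure(lambda x: "undef", o_var)
    print(placeholder)   # {'count': 'undef', 'hist': ['undef', 'undef']} -- same nesting (key order may vary)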