未验证 提交 8571833f 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Enhance nonlocal machanism while returning single var (#43957)

* [Dy2Stat]Enhance nonlocal machanism while returning single var

* [Dy2Stat]Enhance nonlocal machanism while returning single var
上级 ccb333c1
......@@ -50,18 +50,13 @@ def convert_while_loop(cond, body, getter, setter):
def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
def to_list(x):
if isinstance(x, (tuple, list)): return x
return [x]
# UndefinedVar will become data layer not check.
loop_vars = [to_static_variable(var) for var in to_list(getter())]
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
loop_vars = [to_static_variable(var) for var in getter()]
setter(loop_vars) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(cond, body, loop_vars)
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
setter(loop_vars) # change the non-local var to variable
return loop_vars
......@@ -318,11 +313,11 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
init_args = get_args()
# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
num_args = len(init_args)
assert num_outs <= num_args
if num_args == 1:
final_outs = outs
final_outs = (outs, )
else:
outs = (outs, ) if num_outs == 1 else outs
final_outs = outs + init_args[num_outs:]
......
......@@ -38,6 +38,9 @@ from paddle.fluid import core
PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
class BaseNodeVisitor(gast.NodeVisitor):
......@@ -1619,7 +1622,7 @@ def create_get_args_node(names):
template = """
def {func_name}():
nonlocal {nonlocal_vars}
return {vars}
return {vars},
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
......@@ -1628,11 +1631,6 @@ def create_get_args_node(names):
return gast.parse(textwrap.dedent(func_def)).body[0]
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
def create_set_args_node(names):
"""
Create set_args function as follows:
......@@ -1661,7 +1659,7 @@ def create_set_args_node(names):
template = """
def {func_name}({args}):
nonlocal {nonlocal_vars}
{vars} = {args}
{vars}, = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
......
......@@ -74,11 +74,11 @@ class StaticCode1():
def get_args_0():
nonlocal x_v
return x_v
return x_v,
def set_args_0(__args):
nonlocal x_v
x_v = __args
x_v, = __args
def true_fn_0():
nonlocal x_v
......@@ -96,11 +96,11 @@ class StaticCode1():
def get_args_1():
nonlocal __return_value_0, label, x_v
return __return_value_0, label, x_v
return __return_value_0, label, x_v,
def set_args_1(__args):
nonlocal __return_value_0, label, x_v
__return_value_0, label, x_v = __args
__return_value_0, label, x_v, = __args
def true_fn_1():
nonlocal __return_value_0, label, x_v
......@@ -131,11 +131,11 @@ class StaticCode2():
def get_args_2():
nonlocal x_v
return x_v
return x_v,
def set_args_2(__args):
nonlocal x_v
x_v = __args
x_v, = __args
def true_fn_2():
nonlocal x_v
......@@ -153,11 +153,11 @@ class StaticCode2():
def get_args_3():
nonlocal __return_value_1, label, x_v
return __return_value_1, label, x_v
return __return_value_1, label, x_v,
def set_args_3(__args):
nonlocal __return_value_1, label, x_v
__return_value_1, label, x_v = __args
__return_value_1, label, x_v, = __args
def true_fn_3():
nonlocal __return_value_1, label, x_v
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册