Unverified · Commit 8571833f authored by Aurelius84, committed by GitHub

[Dy2Stat] Enhance nonlocal mechanism while returning single var (#43957)

* [Dy2Stat] Enhance nonlocal mechanism while returning single var

* [Dy2Stat] Enhance nonlocal mechanism while returning single var
Parent ccb333c1
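The whole patch hinges on Python's trailing-comma syntax: `return x,` builds a 1-tuple even for a single variable, and `x, = args` unpacks it back, so single-var and multi-var returns now flow through one uniform code path. A minimal standalone sketch of that mechanism (the names below are illustrative, not taken from the patch):

    def demo():
        x = 1

        def get_args():
            nonlocal x
            return x,    # trailing comma: always a tuple, even for one variable

        def set_args(args):
            nonlocal x
            x, = args    # unpack the 1-tuple back into the single nonlocal var

        set_args((x + 41, ))
        assert get_args() == (42, )

    demo()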
@@ -50,18 +50,13 @@ def convert_while_loop(cond, body, getter, setter):
 def _run_paddle_while(cond, body, getter, setter):
     # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
-    def to_list(x):
-        if isinstance(x, (tuple, list)): return x
-        return [x]
-
     # UndefinedVar will become data layer not check.
-    loop_vars = [to_static_variable(var) for var in to_list(getter())]
-    setter(loop_vars if len(loop_vars) > 1 else
-           loop_vars[0])  # change the non-local var to variable
+    loop_vars = [to_static_variable(var) for var in getter()]
+    setter(loop_vars)  # change the non-local var to variable
     # variable maybe modified to inner var. change it into
     loop_vars = control_flow.while_loop(cond, body, loop_vars)
-    setter(loop_vars if len(loop_vars) > 1 else
-           loop_vars[0])  # change the non-local var to variable
+    setter(loop_vars)  # change the non-local var to variable
     return loop_vars
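With `getter()` now guaranteed to return a tuple, the `to_list` shim becomes dead weight and `setter` can take the loop-var list directly, with no `len(loop_vars) > 1` special case. A simplified stand-in for this getter/setter flow, using a plain Python loop in place of `control_flow.while_loop` (names hypothetical):

    def make_counter():
        i = 0

        def getter():
            nonlocal i
            return i,        # always a tuple, so no to_list() is needed

        def setter(values):
            nonlocal i
            i, = values      # values is always a list/tuple of loop vars

        return getter, setter

    getter, setter = make_counter()
    loop_vars = list(getter())          # [0]
    while loop_vars[0] < 3:             # stand-in for control_flow.while_loop
        loop_vars = [loop_vars[0] + 1]
    setter(loop_vars)                   # write the result back to the nonlocal var
    assert getter() == (3, )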
@@ -318,11 +313,11 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
     init_args = get_args()
     # recover args state
     num_outs = len(return_name_ids)
-    num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
+    num_args = len(init_args)
     assert num_outs <= num_args
 
     if num_args == 1:
-        final_outs = outs
+        final_outs = (outs, )
     else:
         outs = (outs, ) if num_outs == 1 else outs
         final_outs = outs + init_args[num_outs:]
...
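Because `init_args` is now always a tuple, `len(init_args)` replaces the old `isinstance` guard, and the single-arg branch wraps `outs` into a 1-tuple so downstream unpacking always sees a tuple. A hedged standalone re-implementation of just this branch logic, for illustration (the real function also writes `final_outs` back through `set_args`):

    def recover_args_state(outs, init_args, num_outs):
        num_args = len(init_args)     # init_args is always a tuple now
        assert num_outs <= num_args
        if num_args == 1:
            final_outs = (outs, )     # wrap the single output into a 1-tuple
        else:
            outs = (outs, ) if num_outs == 1 else outs
            final_outs = outs + init_args[num_outs:]
        return final_outs

    assert recover_args_state('a', ('x', ), num_outs=1) == ('a', )
    assert recover_args_state('a', ('x', 'y'), num_outs=1) == ('a', 'y')
    assert recover_args_state(('a', 'b'), ('x', 'y', 'z'), num_outs=2) == ('a', 'b', 'z')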
@@ -38,6 +38,9 @@ from paddle.fluid import core
 PADDLE_MODULE_PREFIX = 'paddle.'
 DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
 DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
+GET_ARGS_FUNC_PREFIX = 'get_args'
+SET_ARGS_FUNC_PREFIX = 'set_args'
+ARGS_NAME = '__args'
 
 
 class BaseNodeVisitor(gast.NodeVisitor):
@@ -1619,7 +1622,7 @@ def create_get_args_node(names):
     template = """
     def {func_name}():
         nonlocal {nonlocal_vars}
-        return {vars}
+        return {vars},
     """
     func_def = template.format(
         func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
@@ -1628,11 +1631,6 @@ def create_get_args_node(names):
     return gast.parse(textwrap.dedent(func_def)).body[0]
 
 
-GET_ARGS_FUNC_PREFIX = 'get_args'
-SET_ARGS_FUNC_PREFIX = 'set_args'
-ARGS_NAME = '__args'
-
-
 def create_set_args_node(names):
     """
     Create set_args function as follows:
@@ -1661,7 +1659,7 @@ def create_set_args_node(names):
     template = """
     def {func_name}({args}):
         nonlocal {nonlocal_vars}
-        {vars} = {args}
+        {vars}, = {args}
     """
     func_def = template.format(
         func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
...
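Expanding the two updated templates by hand for `names == ['x_v']` shows the source that gets parsed by `gast.parse` and that the test expectations below encode (the `_0` suffix is an assumption here; in the real code `unique_name.generate` picks it, and `__args` comes from `ARGS_NAME`):

    import textwrap

    get_template = """
    def {func_name}():
        nonlocal {nonlocal_vars}
        return {vars},
    """
    set_template = """
    def {func_name}({args}):
        nonlocal {nonlocal_vars}
        {vars}, = {args}
    """
    # Prints defs whose last lines are 'return x_v,' and 'x_v, = __args':
    # 1-tuples on both the read path and the write path.
    print(textwrap.dedent(get_template.format(
        func_name='get_args_0', nonlocal_vars='x_v', vars='x_v')))
    print(textwrap.dedent(set_template.format(
        func_name='set_args_0', args='__args', nonlocal_vars='x_v', vars='x_v')))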
@@ -74,11 +74,11 @@ class StaticCode1():
 
         def get_args_0():
             nonlocal x_v
-            return x_v
+            return x_v,
 
         def set_args_0(__args):
             nonlocal x_v
-            x_v = __args
+            x_v, = __args
 
         def true_fn_0():
             nonlocal x_v
@@ -96,11 +96,11 @@ class StaticCode1():
 
         def get_args_1():
             nonlocal __return_value_0, label, x_v
-            return __return_value_0, label, x_v
+            return __return_value_0, label, x_v,
 
         def set_args_1(__args):
             nonlocal __return_value_0, label, x_v
-            __return_value_0, label, x_v = __args
+            __return_value_0, label, x_v, = __args
 
         def true_fn_1():
             nonlocal __return_value_0, label, x_v
@@ -131,11 +131,11 @@ class StaticCode2():
 
         def get_args_2():
             nonlocal x_v
-            return x_v
+            return x_v,
 
         def set_args_2(__args):
             nonlocal x_v
-            x_v = __args
+            x_v, = __args
 
         def true_fn_2():
             nonlocal x_v
@@ -153,11 +153,11 @@ class StaticCode2():
 
         def get_args_3():
             nonlocal __return_value_1, label, x_v
-            return __return_value_1, label, x_v
+            return __return_value_1, label, x_v,
 
         def set_args_3(__args):
             nonlocal __return_value_1, label, x_v
-            __return_value_1, label, x_v = __args
+            __return_value_1, label, x_v, = __args
 
         def true_fn_3():
             nonlocal __return_value_1, label, x_v
...