From 8571833fc7e00b37ccd417c57a94415210a82d6b Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 1 Jul 2022 11:01:54 +0800 Subject: [PATCH] [Dy2Stat]Enhance nonlocal machanism while returning single var (#43957) * [Dy2Stat]Enhance nonlocal machanism while returning single var * [Dy2Stat]Enhance nonlocal machanism while returning single var --- .../dygraph_to_static/convert_operators.py | 15 +++++---------- .../fluid/dygraph/dygraph_to_static/utils.py | 12 +++++------- .../dygraph_to_static/test_program_translator.py | 16 ++++++++-------- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index a6cab0db513..c0c679e2e1e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -50,18 +50,13 @@ def convert_while_loop(cond, body, getter, setter): def _run_paddle_while(cond, body, getter, setter): # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors. - def to_list(x): - if isinstance(x, (tuple, list)): return x - return [x] # UndefinedVar will become data layer not check. - loop_vars = [to_static_variable(var) for var in to_list(getter())] - setter(loop_vars if len(loop_vars) > 1 else - loop_vars[0]) # change the non-local var to variable + loop_vars = [to_static_variable(var) for var in getter()] + setter(loop_vars) # change the non-local var to variable # variable maybe modified to inner var. change it into loop_vars = control_flow.while_loop(cond, body, loop_vars) - setter(loop_vars if len(loop_vars) > 1 else - loop_vars[0]) # change the non-local var to variable + setter(loop_vars) # change the non-local var to variable return loop_vars @@ -318,11 +313,11 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids): init_args = get_args() # recover args state num_outs = len(return_name_ids) - num_args = 1 if not isinstance(init_args, tuple) else len(init_args) + num_args = len(init_args) assert num_outs <= num_args if num_args == 1: - final_outs = outs + final_outs = (outs, ) else: outs = (outs, ) if num_outs == 1 else outs final_outs = outs + init_args[num_outs:] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 466e9ee4d34..b51635b85f9 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -38,6 +38,9 @@ from paddle.fluid import core PADDLE_MODULE_PREFIX = 'paddle.' DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph' DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static' +GET_ARGS_FUNC_PREFIX = 'get_args' +SET_ARGS_FUNC_PREFIX = 'set_args' +ARGS_NAME = '__args' class BaseNodeVisitor(gast.NodeVisitor): @@ -1619,7 +1622,7 @@ def create_get_args_node(names): template = """ def {func_name}(): nonlocal {nonlocal_vars} - return {vars} + return {vars}, """ func_def = template.format( func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), @@ -1628,11 +1631,6 @@ def create_get_args_node(names): return gast.parse(textwrap.dedent(func_def)).body[0] -GET_ARGS_FUNC_PREFIX = 'get_args' -SET_ARGS_FUNC_PREFIX = 'set_args' -ARGS_NAME = '__args' - - def create_set_args_node(names): """ Create set_args function as follows: @@ -1661,7 +1659,7 @@ def create_set_args_node(names): template = """ def {func_name}({args}): nonlocal {nonlocal_vars} - {vars} = {args} + {vars}, = {args} """ func_def = template.format( func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 41968278f7b..8d2665129e9 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -74,11 +74,11 @@ class StaticCode1(): def get_args_0(): nonlocal x_v - return x_v + return x_v, def set_args_0(__args): nonlocal x_v - x_v = __args + x_v, = __args def true_fn_0(): nonlocal x_v @@ -96,11 +96,11 @@ class StaticCode1(): def get_args_1(): nonlocal __return_value_0, label, x_v - return __return_value_0, label, x_v + return __return_value_0, label, x_v, def set_args_1(__args): nonlocal __return_value_0, label, x_v - __return_value_0, label, x_v = __args + __return_value_0, label, x_v, = __args def true_fn_1(): nonlocal __return_value_0, label, x_v @@ -131,11 +131,11 @@ class StaticCode2(): def get_args_2(): nonlocal x_v - return x_v + return x_v, def set_args_2(__args): nonlocal x_v - x_v = __args + x_v, = __args def true_fn_2(): nonlocal x_v @@ -153,11 +153,11 @@ class StaticCode2(): def get_args_3(): nonlocal __return_value_1, label, x_v - return __return_value_1, label, x_v + return __return_value_1, label, x_v, def set_args_3(__args): nonlocal __return_value_1, label, x_v - __return_value_1, label, x_v = __args + __return_value_1, label, x_v, = __args def true_fn_3(): nonlocal __return_value_1, label, x_v -- GitLab