From 40a773198b42f2794c97c831a24e510164c0c664 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 27 Jun 2022 14:42:43 +0800 Subject: [PATCH] [Dy2Stat]Enhance nonlocal machanism while nonlocal vars is empty (#43848) --- .../dygraph_to_static/ifelse_transformer.py | 27 ++++++++++++++++++- .../dygraph_to_static/ifelse_simple_func.py | 12 +++++++++ .../dygraph_to_static/test_ifelse.py | 7 +++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index bff41c9b9ae..f4d19905975 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -524,7 +524,8 @@ def transform_if_else(node, root): if ARGS_NAME in nonlocal_names: nonlocal_names.remove(ARGS_NAME) - nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] + nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names) + ] if nonlocal_names else [] empty_arg_node = gast.arguments(args=[], posonlyargs=[], @@ -557,8 +558,20 @@ def create_get_args_node(names): def get_args_0(): nonlocal x, y + return x, y """ + + def empty_node(): + func_def = """ + def {func_name}(): + return + """.format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) + return gast.parse(textwrap.dedent(func_def)).body[0] + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + template = """ def {func_name}(): nonlocal {vars} @@ -578,7 +591,19 @@ def create_set_args_node(names): nonlocal x, y x, y = __args """ + + def empty_node(): + func_def = """ + def {func_name}({args}): + pass + """.format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), + args=ARGS_NAME) + return gast.parse(textwrap.dedent(func_def)).body[0] + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + template = """ def {func_name}({args}): nonlocal {vars} diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 39565044e7f..b5ba4c89ee2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -28,6 +28,18 @@ def loss_fn(x, lable): return loss +def dyfunc_empty_nonlocal(x): + flag = True + if flag: + print("It's a test for empty nonlocal stmt") + + if paddle.mean(x) < 0: + x + 1 + + out = x * 2 + return out + + def dyfunc_with_if_else(x_v, label=None): if fluid.layers.mean(x_v).numpy()[0] > 5: x_v = x_v - 1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 5ce163c7685..822835a8c7c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -73,6 +73,13 @@ class TestDygraphIfElse3(TestDygraphIfElse): self.dyfunc = dyfunc_with_if_else3 +class TestDygraphIfElse4(TestDygraphIfElse): + + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_empty_nonlocal + + class TestDygraphIfElseWithListGenerator(TestDygraphIfElse): def setUp(self): -- GitLab