diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index bff41c9b9ae029913c79a51f5301c6f7ab112b0b..f4d19905975d95a79724b9d08cb95bc10a0cfdcc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -524,7 +524,8 @@ def transform_if_else(node, root): if ARGS_NAME in nonlocal_names: nonlocal_names.remove(ARGS_NAME) - nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] + nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names) + ] if nonlocal_names else [] empty_arg_node = gast.arguments(args=[], posonlyargs=[], @@ -557,8 +558,20 @@ def create_get_args_node(names): def get_args_0(): nonlocal x, y + return x, y """ + + def empty_node(): + func_def = """ + def {func_name}(): + return + """.format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) + return gast.parse(textwrap.dedent(func_def)).body[0] + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + template = """ def {func_name}(): nonlocal {vars} @@ -578,7 +591,19 @@ def create_set_args_node(names): nonlocal x, y x, y = __args """ + + def empty_node(): + func_def = """ + def {func_name}({args}): + pass + """.format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), + args=ARGS_NAME) + return gast.parse(textwrap.dedent(func_def)).body[0] + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + template = """ def {func_name}({args}): nonlocal {vars} diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 39565044e7fd1cf2e5eb3db73f65bc96edc05617..b5ba4c89ee2e3c0d8b03a514b0ce169a9dcfcf32 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -28,6 +28,18 @@ def loss_fn(x, lable): return loss +def dyfunc_empty_nonlocal(x): + flag = True + if flag: + print("It's a test for empty nonlocal stmt") + + if paddle.mean(x) < 0: + x + 1 + + out = x * 2 + return out + + def dyfunc_with_if_else(x_v, label=None): if fluid.layers.mean(x_v).numpy()[0] > 5: x_v = x_v - 1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 5ce163c76855d11982e0b7cf3af75334a35bef11..822835a8c7cd1e46379df02c6d445604d646cfea 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -73,6 +73,13 @@ class TestDygraphIfElse3(TestDygraphIfElse): self.dyfunc = dyfunc_with_if_else3 +class TestDygraphIfElse4(TestDygraphIfElse): + + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_empty_nonlocal + + class TestDygraphIfElseWithListGenerator(TestDygraphIfElse): def setUp(self):