未验证 提交 40a77319 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Enhance nonlocal machanism while nonlocal vars is empty (#43848)

上级 e6e1c5e7
...@@ -524,7 +524,8 @@ def transform_if_else(node, root): ...@@ -524,7 +524,8 @@ def transform_if_else(node, root):
if ARGS_NAME in nonlocal_names: if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME) nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)
] if nonlocal_names else []
empty_arg_node = gast.arguments(args=[], empty_arg_node = gast.arguments(args=[],
posonlyargs=[], posonlyargs=[],
...@@ -557,8 +558,20 @@ def create_get_args_node(names): ...@@ -557,8 +558,20 @@ def create_get_args_node(names):
def get_args_0(): def get_args_0():
nonlocal x, y nonlocal x, y
return x, y
""" """
def empty_node():
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
if not names:
return empty_node()
template = """ template = """
def {func_name}(): def {func_name}():
nonlocal {vars} nonlocal {vars}
...@@ -578,7 +591,19 @@ def create_set_args_node(names): ...@@ -578,7 +591,19 @@ def create_set_args_node(names):
nonlocal x, y nonlocal x, y
x, y = __args x, y = __args
""" """
def empty_node():
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
if not names:
return empty_node()
template = """ template = """
def {func_name}({args}): def {func_name}({args}):
nonlocal {vars} nonlocal {vars}
......
...@@ -28,6 +28,18 @@ def loss_fn(x, lable): ...@@ -28,6 +28,18 @@ def loss_fn(x, lable):
return loss return loss
def dyfunc_empty_nonlocal(x):
flag = True
if flag:
print("It's a test for empty nonlocal stmt")
if paddle.mean(x) < 0:
x + 1
out = x * 2
return out
def dyfunc_with_if_else(x_v, label=None): def dyfunc_with_if_else(x_v, label=None):
if fluid.layers.mean(x_v).numpy()[0] > 5: if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1 x_v = x_v - 1
......
...@@ -73,6 +73,13 @@ class TestDygraphIfElse3(TestDygraphIfElse): ...@@ -73,6 +73,13 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self.dyfunc = dyfunc_with_if_else3 self.dyfunc = dyfunc_with_if_else3
class TestDygraphIfElse4(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_empty_nonlocal
class TestDygraphIfElseWithListGenerator(TestDygraphIfElse): class TestDygraphIfElseWithListGenerator(TestDygraphIfElse):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册