未验证 提交 67d77846 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] fix non-local error while dealing push_pop names (#45828)

* 1. fix non-local error while dealing push_pop names
2. escape "'" in push_pop_names to avoid syntax errors.
3. unified the non-local stmt creation processes in getter and setter.
4. split the nonlocal_names and getter/setter names.

* fix bugs

* 1. revert setter and getter, push_pop_names must have non-local

* fix bugs.

* code format
上级 1a929c31
...@@ -35,6 +35,7 @@ from paddle.fluid.layer_helper import LayerHelper ...@@ -35,6 +35,7 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign from paddle.fluid.layers import assign
import collections import collections
from functools import reduce from functools import reduce
import warnings
# Note(Aurelius): Do not forget the dot `.` to distinguish other # Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp. # module such as paddlenlp.
...@@ -1024,6 +1025,7 @@ class NameScope: ...@@ -1024,6 +1025,7 @@ class NameScope:
self.father = None # point to the nearest function name scope. self.father = None # point to the nearest function name scope.
self.w_vars = set() # all qualified + normal names been stored self.w_vars = set() # all qualified + normal names been stored
self.created = set() # useful for control flow compatibility self.created = set() # useful for control flow compatibility
# only valid in control_flow nodes
# may be remove later. # may be remove later.
self.push_pop_vars = set() # we call push and pop in the vars self.push_pop_vars = set() # we call push and pop in the vars
...@@ -1045,15 +1047,54 @@ class NameScope: ...@@ -1045,15 +1047,54 @@ class NameScope:
return self.w_vars return self.w_vars
def variadic_length_vars(self): def variadic_length_vars(self):
return self.push_pop_vars """
At present, we do not support global append, such as
import numpy as np
a = []
def func():
a.append() # global names `a`, we will raise a warning.
p.append(a, 1) # global names `np`, we will raise a warning.
"""
non_global_push_pop_names = []
for var in self.push_pop_vars:
if self._is_simple_name(var) and self.is_global_var(var):
warnings.warn(
f"Find variable `{var}` defined in global scope"
f" and call `{var}.append() or {var}.pop()`"
f", which will be ignored and never be transfered into"
f" tensor array.")
else:
non_global_push_pop_names.append(var)
return set(non_global_push_pop_names)
def control_flow_vars(self): def control_flow_vars(self):
valid_names = self.w_vars valid_names = self.w_vars
tmp = self.father.global_vars & valid_names, tmp = self.father.global_vars & valid_names,
return {"global": tmp, "nonlocal": self.w_vars - tmp} return {"global": tmp, "nonlocal": self.w_vars - tmp}
def global_vars(self): def _is_simple_name(self, name):
return self.globals if '.' in name or '[' in name: return False
return True
def is_global_var(self, name):
"""
Return whether the name is a var created in global scope.
Search from bottom to top. If it is not created or modified,
it means global vars; otherwise, it means local vars.
Only valid after FunctionNameLivenessAnalysis visitor.
"""
assert self._is_simple_name(
name), "is_global_var accept a simple name, but get `{name}`."
ancestor = self
while ancestor is not None:
if name in ancestor.globals: return True
if name in (ancestor.nonlocals | ancestor.w_vars): return False
ancestor = ancestor.father
return True
def is_local_var(self, name):
return not self.is_global_var(name)
def merge_from(self, name_scope): def merge_from(self, name_scope):
self.globals |= name_scope.globals self.globals |= name_scope.globals
...@@ -1186,7 +1227,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1186,7 +1227,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" """
self._reset_name_scope(node) self._reset_name_scope(node)
self.scope_node_stack.append(node) self.scope_node_stack.append(node)
self._current_name_scope().father = self._nearest_function_scope() self._current_name_scope().set_father(self._nearest_function_scope())
if pre_func: pre_func() if pre_func: pre_func()
self.generic_visit(node) self.generic_visit(node)
if post_func: post_func() if post_func: post_func()
...@@ -1274,16 +1315,13 @@ def create_get_args_node(names): ...@@ -1274,16 +1315,13 @@ def create_get_args_node(names):
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
mapped = list(filter(lambda n: '.' not in n, names)) node = create_nonlocal_stmt_nodes(names)
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
if not names: if not names:
return empty_node() return empty_node()
if not nonlocal_names: if node == []:
nonlocal_vars = "\n" nonlocal_vars = "\n"
else: else:
nonlocal_vars = "nonlocal " + ",".join(nonlocal_names) nonlocal_vars = ast_to_source_code(node[0])
template = """ template = """
def {func_name}(): def {func_name}():
{nonlocal_vars} {nonlocal_vars}
...@@ -1314,16 +1352,13 @@ def create_set_args_node(names): ...@@ -1314,16 +1352,13 @@ def create_set_args_node(names):
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
mapped = list(filter(lambda n: '.' not in n, names)) node = create_nonlocal_stmt_nodes(names)
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
if not names: if not names:
return empty_node() return empty_node()
if not nonlocal_names: if node == []:
nonlocal_vars = "\n" nonlocal_vars = "\n"
else: else:
nonlocal_vars = "nonlocal " + ",".join(nonlocal_names) nonlocal_vars = ast_to_source_code(node[0])
template = """ template = """
def {func_name}({args}): def {func_name}({args}):
{nonlocal_vars} {nonlocal_vars}
...@@ -1341,6 +1376,7 @@ def create_nonlocal_stmt_nodes(names): ...@@ -1341,6 +1376,7 @@ def create_nonlocal_stmt_nodes(names):
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
mapped = list(filter(lambda n: '.' not in n, names)) mapped = list(filter(lambda n: '.' not in n, names))
mapped = list(filter(lambda n: '[' not in n, mapped))
names = sorted( names = sorted(
mapped, mapped,
key=mapped.index) # to keep the order, we can't use set() to unique key=mapped.index) # to keep the order, we can't use set() to unique
...@@ -1400,5 +1436,5 @@ def create_name_str(name_ids): ...@@ -1400,5 +1436,5 @@ def create_name_str(name_ids):
if not name_ids: if not name_ids:
return 'None' return 'None'
names_str = ["'%s'" % name for name in name_ids] names_str = ["'%s'" % (name.replace("'", "\\'")) for name in name_ids]
return "(%s, )" % ','.join(names_str) return "(%s, )" % ','.join(names_str)
...@@ -20,6 +20,9 @@ import paddle ...@@ -20,6 +20,9 @@ import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
from paddle.utils import gast from paddle.utils import gast
import inspect import inspect
from numpy import append
global_a = []
class JudgeVisitor(gast.NodeVisitor): class JudgeVisitor(gast.NodeVisitor):
...@@ -257,5 +260,70 @@ class TestClosureAnalysis_PushPop(TestClosureAnalysis): ...@@ -257,5 +260,70 @@ class TestClosureAnalysis_PushPop(TestClosureAnalysis):
}] }]
class TestPushPopTrans(unittest.TestCase):
def test(self):
def vlist_of_dict(x):
ma = {'a': []}
for i in range(3):
ma['a'].append(1)
return ma
x = paddle.to_tensor([3])
print(paddle.jit.to_static(vlist_of_dict).code)
print(paddle.jit.to_static(vlist_of_dict)(x))
def test2(self):
import numpy as np
def vlist_of_dict(x):
a = np.array([1, 2, 3])
for i in range(3):
np.append(a, 4)
return a
x = paddle.to_tensor([3])
print(paddle.jit.to_static(vlist_of_dict).code)
print(paddle.jit.to_static(vlist_of_dict)(x))
def test3(self):
import numpy as np
def vlist_of_dict(x):
a = np.array([1, 2, 3])
if True:
pass
return a
x = paddle.to_tensor([3])
print(paddle.jit.to_static(vlist_of_dict).code)
print(paddle.jit.to_static(vlist_of_dict)(x))
def test4(self):
def vlist_of_dict(x):
a = np.array([1, 2, 3])
for i in range(3):
append(a, 4)
return a
x = paddle.to_tensor([3])
print(paddle.jit.to_static(vlist_of_dict).code)
print(paddle.jit.to_static(vlist_of_dict)(x))
def test5(self):
def vlist_of_dict(x):
a = np.array([1, 2, 3])
for i in range(3):
global_a.append(4)
return a
x = paddle.to_tensor([3])
print(paddle.jit.to_static(vlist_of_dict).code)
print(paddle.jit.to_static(vlist_of_dict)(x))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册