diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 41cd4676e608a9780e855206152391d843532e46..9a8586e3761cceac36b2fde48ae8d4a0161f509a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -35,6 +35,7 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers import assign import collections from functools import reduce +import warnings # Note(Aurelius): Do not forget the dot `.` to distinguish other # module such as paddlenlp. @@ -1024,6 +1025,7 @@ class NameScope: self.father = None # point to the nearest function name scope. self.w_vars = set() # all qualified + normal names been stored self.created = set() # useful for control flow compatibility + # only valid in control_flow nodes # may be remove later. self.push_pop_vars = set() # we call push and pop in the vars @@ -1045,15 +1047,54 @@ class NameScope: return self.w_vars def variadic_length_vars(self): - return self.push_pop_vars + """ + At present, we do not support global append, such as + + import numpy as np + a = [] + def func(): + a.append() # global names `a`, we will raise a warning. + p.append(a, 1) # global names `np`, we will raise a warning. + """ + non_global_push_pop_names = [] + for var in self.push_pop_vars: + if self._is_simple_name(var) and self.is_global_var(var): + warnings.warn( + f"Find variable `{var}` defined in global scope" + f" and call `{var}.append() or {var}.pop()`" + f", which will be ignored and never be transfered into" + f" tensor array.") + else: + non_global_push_pop_names.append(var) + return set(non_global_push_pop_names) def control_flow_vars(self): valid_names = self.w_vars tmp = self.father.global_vars & valid_names, return {"global": tmp, "nonlocal": self.w_vars - tmp} - def global_vars(self): - return self.globals + def _is_simple_name(self, name): + if '.' in name or '[' in name: return False + return True + + def is_global_var(self, name): + """ + Return whether the name is a var created in global scope. + Search from bottom to top. If it is not created or modified, + it means global vars; otherwise, it means local vars. + Only valid after FunctionNameLivenessAnalysis visitor. + """ + assert self._is_simple_name( + name), "is_global_var accept a simple name, but get `{name}`." + ancestor = self + while ancestor is not None: + if name in ancestor.globals: return True + if name in (ancestor.nonlocals | ancestor.w_vars): return False + ancestor = ancestor.father + return True + + def is_local_var(self, name): + return not self.is_global_var(name) def merge_from(self, name_scope): self.globals |= name_scope.globals @@ -1186,7 +1227,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): """ self._reset_name_scope(node) self.scope_node_stack.append(node) - self._current_name_scope().father = self._nearest_function_scope() + self._current_name_scope().set_father(self._nearest_function_scope()) if pre_func: pre_func() self.generic_visit(node) if post_func: post_func() @@ -1274,16 +1315,13 @@ def create_get_args_node(names): return gast.parse(textwrap.dedent(func_def)).body[0] assert isinstance(names, (list, tuple)) - mapped = list(filter(lambda n: '.' not in n, names)) - nonlocal_names = sorted( - mapped, - key=mapped.index) # to keep the order, we can't use set() to unique + node = create_nonlocal_stmt_nodes(names) if not names: return empty_node() - if not nonlocal_names: + if node == []: nonlocal_vars = "\n" else: - nonlocal_vars = "nonlocal " + ",".join(nonlocal_names) + nonlocal_vars = ast_to_source_code(node[0]) template = """ def {func_name}(): {nonlocal_vars} @@ -1314,16 +1352,13 @@ def create_set_args_node(names): return gast.parse(textwrap.dedent(func_def)).body[0] assert isinstance(names, (list, tuple)) - mapped = list(filter(lambda n: '.' not in n, names)) - nonlocal_names = sorted( - mapped, - key=mapped.index) # to keep the order, we can't use set() to unique + node = create_nonlocal_stmt_nodes(names) if not names: return empty_node() - if not nonlocal_names: + if node == []: nonlocal_vars = "\n" else: - nonlocal_vars = "nonlocal " + ",".join(nonlocal_names) + nonlocal_vars = ast_to_source_code(node[0]) template = """ def {func_name}({args}): {nonlocal_vars} @@ -1341,6 +1376,7 @@ def create_nonlocal_stmt_nodes(names): assert isinstance(names, (list, tuple)) mapped = list(filter(lambda n: '.' not in n, names)) + mapped = list(filter(lambda n: '[' not in n, mapped)) names = sorted( mapped, key=mapped.index) # to keep the order, we can't use set() to unique @@ -1400,5 +1436,5 @@ def create_name_str(name_ids): if not name_ids: return 'None' - names_str = ["'%s'" % name for name in name_ids] + names_str = ["'%s'" % (name.replace("'", "\\'")) for name in name_ids] return "(%s, )" % ','.join(names_str) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py index 52e679323267017566924c6c6a26e2d885107f3d..227191a68fe38e290dfb635d539aa997bcfbaa27 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py @@ -20,6 +20,9 @@ import paddle from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis from paddle.utils import gast import inspect +from numpy import append + +global_a = [] class JudgeVisitor(gast.NodeVisitor): @@ -257,5 +260,70 @@ class TestClosureAnalysis_PushPop(TestClosureAnalysis): }] +class TestPushPopTrans(unittest.TestCase): + + def test(self): + + def vlist_of_dict(x): + ma = {'a': []} + for i in range(3): + ma['a'].append(1) + return ma + + x = paddle.to_tensor([3]) + print(paddle.jit.to_static(vlist_of_dict).code) + print(paddle.jit.to_static(vlist_of_dict)(x)) + + def test2(self): + import numpy as np + + def vlist_of_dict(x): + a = np.array([1, 2, 3]) + for i in range(3): + np.append(a, 4) + return a + + x = paddle.to_tensor([3]) + print(paddle.jit.to_static(vlist_of_dict).code) + print(paddle.jit.to_static(vlist_of_dict)(x)) + + def test3(self): + import numpy as np + + def vlist_of_dict(x): + a = np.array([1, 2, 3]) + if True: + pass + return a + + x = paddle.to_tensor([3]) + print(paddle.jit.to_static(vlist_of_dict).code) + print(paddle.jit.to_static(vlist_of_dict)(x)) + + def test4(self): + + def vlist_of_dict(x): + a = np.array([1, 2, 3]) + for i in range(3): + append(a, 4) + return a + + x = paddle.to_tensor([3]) + print(paddle.jit.to_static(vlist_of_dict).code) + print(paddle.jit.to_static(vlist_of_dict)(x)) + + def test5(self): + + def vlist_of_dict(x): + a = np.array([1, 2, 3]) + for i in range(3): + global_a.append(4) + return a + + x = paddle.to_tensor([3]) + print(paddle.jit.to_static(vlist_of_dict).code) + print(paddle.jit.to_static(vlist_of_dict)(x)) + + if __name__ == '__main__': unittest.main()