From a0846b627a022d3b145d696273596c7117cd9c00 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 26 May 2020 17:30:05 +0800 Subject: [PATCH] Remove target vars of gast.For from before_loop_vars or after_loop_vars (#24732) --- .../dygraph_to_static/loop_transformer.py | 62 ++++++++++++++++++- .../unittests/dygraph_to_static/test_loop.py | 37 +++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index d4d1ff6ba2d..b9e6eff2f9b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -166,13 +166,19 @@ class NameVisitor(gast.NodeVisitor): in_loop_vars = self.in_loop_vars[node] in_loop_name_strs = self._var_nodes_to_names(in_loop_vars) + before_loop_body_vars = self.before_loop_body_vars[node] + before_loop_body_vars = self._remove_target_vars_of_for( + before_loop_body_vars, node) before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars) + after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars + after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node) after_loop_name_strs = self._var_nodes_to_names(after_loop_vars, read_context) condition_vars = self.condition_vars[node] condition_names = self._var_nodes_to_names(condition_vars) + write_vars = self.write_in_loop[node] write_names = self._var_nodes_to_names(write_vars) @@ -203,6 +209,7 @@ class NameVisitor(gast.NodeVisitor): # vars out loop_var_names.add(name) create_var_names.add(name) + return loop_var_names, create_var_names def visit_Name(self, node): @@ -221,8 +228,8 @@ class NameVisitor(gast.NodeVisitor): self.in_loop_vars[loop_node].add(node) if type(node.ctx) in write_context: self.write_in_loop[loop_node].add(node) - if self.in_condition: - self.condition_vars[loop_node].add(node) + if self.in_condition: + self.condition_vars[loop_node].add(node) self.generic_visit(node) def visit_FunctionDef(self, node): @@ -309,11 +316,60 @@ class NameVisitor(gast.NodeVisitor): return False def _is_call_func_name_node(self, node): - parent_node = self.node_to_wrapper_map[node].parent.node + parent_node = self._get_parent_node(node) if isinstance(parent_node, gast.Call) and parent_node.func == node: return True return False + def _get_parent_node(self, node): + wrapper_node = self.node_to_wrapper_map.get(node) + if wrapper_node: + parent_node = wrapper_node.parent.node + return parent_node + return None + + def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node): + """ + Remove target vars of gast.For from before_loop_vars or after_loop_vars. + :param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node. + :param loop_node: Current loop node. + """ + + removed_vars = set() + for name_node in before_or_after_loop_vars: + if not isinstance(name_node, gast.Name): + continue + + parent_node = self._get_parent_node(name_node) + + # NOTE: gast.For.target can be gast.Tuple. + # For example: `for i, j in enumerate(x)` has two target vars: i and j + if isinstance(parent_node, gast.Tuple): + parent_node = self._get_parent_node(parent_node) + + if isinstance(parent_node, + gast.For) and parent_node is not loop_node: + target_node = parent_node.target + + if isinstance(target_node, gast.Tuple): + target_vars = target_node.elts + else: + target_vars = [target_node] + + if name_node in target_vars: + removed_vars.add(name_node) + + removed_vars_name_strs = {var.id for var in removed_vars} + + for var in before_or_after_loop_vars: + if not isinstance(var, gast.Name): + continue + if var.id in removed_vars_name_strs and var not in self.condition_vars[ + loop_node]: + removed_vars.add(var) + + return before_or_after_loop_vars - removed_vars + class LoopTransformer(gast.NodeTransformer): """ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 66f153d9ef0..08b1336152c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -132,6 +132,19 @@ def var_create_in_for_loop(max_len): return ret +def nested_for_loop_dyfunc(): + two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") + three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32") + for j in range(two): + for i in range(10): + a = 2 + + for i in range(three): + b = fluid.layers.zeros(shape=[1], dtype='float32') + + return b + + class TestNameVisitor(unittest.TestCase): def setUp(self): self.loop_funcs = [ @@ -142,6 +155,8 @@ class TestNameVisitor(unittest.TestCase): ] self.create_var_names = [set(), set(["ret"]), set()] + self.nested_for_loop_func = nested_for_loop_dyfunc + def test_loop_vars(self): for i in range(len(self.loop_funcs)): func = self.loop_funcs[i] @@ -155,6 +170,28 @@ class TestNameVisitor(unittest.TestCase): self.assertEqual(loop_var_names, self.loop_var_names[i]) self.assertEqual(create_var_names, self.create_var_names[i]) + def test_nested_loop_vars(self): + func = self.nested_for_loop_func + test_func = inspect.getsource(func) + gast_root = gast.parse(test_func) + name_visitor = NameVisitor(gast_root) + + self.loop_var_names = [ + set(["j", "two"]), + set(["i", "three", "b"]), + set(["i"]), + ] + self.create_var_names = [set(), set(["b"]), set()] + i = 0 + for node in gast.walk(gast_root): + if isinstance(node, (gast.While, gast.For)): + loop_var_names, create_var_names = name_visitor.get_loop_var_names( + node) + # print(loop_var_names) + self.assertEqual(loop_var_names, self.loop_var_names[i]) + self.assertEqual(create_var_names, self.create_var_names[i]) + i += 1 + class TestTransformWhileLoop(unittest.TestCase): def setUp(self): -- GitLab