From e5f0e6b0033d1177f47c5c6212e6eea6e30d635e Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 19 Nov 2020 18:43:42 +0800 Subject: [PATCH] [Dynamic-to-Static] Fix bug in loop_transformer: loop vars should contain the var from ancestor-for-node (#28735) --- .../dygraph_to_static/loop_transformer.py | 33 ++++++++++++++++--- .../unittests/dygraph_to_static/test_loop.py | 17 +++++++--- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 8e3ca72788..9c1271c1cd 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -294,11 +294,21 @@ class NameVisitor(gast.NodeVisitor): return True return False + def _is_ancestor_node(self, ancestor_node, node): + parent_node = self._get_parent_node(node) + + while parent_node is not None: + if parent_node == ancestor_node: + return True + parent_node = self._get_parent_node(parent_node) + return False + def _get_parent_node(self, node): wrapper_node = self.node_to_wrapper_map.get(node) if wrapper_node: - parent_node = wrapper_node.parent.node - return parent_node + if wrapper_node.parent: + parent_node = wrapper_node.parent.node + return parent_node return None def _remove_unnecessary_vars(self, loop_vars, loop_node): @@ -355,9 +365,22 @@ class NameVisitor(gast.NodeVisitor): if child_node.id in target_var_names: vars_of_list_generator.add(child_node) - # 2. Get target vars or vars from target vars used in for-loop. - elif isinstance(parent_node, - gast.For) and parent_node is not loop_node: + # 2. Get target vars or vars from target vars used in for-loop but the for-loop is + # 1) not the "loop_node" itself + # 2) not the ancestor of the "loop_node" + # + # For examples: + # for k in range(x): # if it's this "loop_node", i or j both should be target vars. + # # do something + # + # for i in range(a): # if it's this "loop_node", k or j should be in target vars but i should not. + # for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not. + # x = i+j + elif isinstance(parent_node, gast.For): + if parent_node is loop_node: + continue + if self._is_ancestor_node(parent_node, loop_node): + continue # 2.1 target vars in gast.For node. target_node = parent_node.target if isinstance(target_node, gast.Tuple): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index bf9b579b68..2f107e53ab 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -161,7 +161,7 @@ def nested_for_loop_dyfunc(): three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32") for j in range(two): for i in range(10): - a = 2 + a = 2 + j for i in range(three): b = fluid.layers.zeros(shape=[1], dtype='float32') @@ -216,16 +216,25 @@ class TestNameVisitor(unittest.TestCase): self.loop_var_names = [ set(["j", "two"]), set(["i", "three", "b"]), - set(["i"]), + set(["i", "j"]), ] self.create_var_names = [set(), set(["b"]), set()] + i = 0 for node in gast.walk(gast_root): if isinstance(node, (gast.While, gast.For)): loop_var_names, create_var_names = name_visitor.get_loop_var_names( node) - self.assertEqual(loop_var_names, self.loop_var_names[i]) - self.assertEqual(create_var_names, self.create_var_names[i]) + self.assertEqual( + loop_var_names, + self.loop_var_names[i], + msg="loop_var_names : {}, \nexpected loop_var_names : {}". + format(loop_var_names, self.loop_var_names[i])) + self.assertEqual( + create_var_names, + self.create_var_names[i], + msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}". + format(i, create_var_names, self.create_var_names[i])) i += 1 -- GitLab