未验证 提交 e5f0e6b0 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Fix bug in loop_transformer: loop vars should contain the...

[Dynamic-to-Static] Fix bug in loop_transformer: loop vars should contain the var from ancestor-for-node (#28735)
上级 04cefeac
...@@ -294,11 +294,21 @@ class NameVisitor(gast.NodeVisitor): ...@@ -294,11 +294,21 @@ class NameVisitor(gast.NodeVisitor):
return True return True
return False return False
def _is_ancestor_node(self, ancestor_node, node):
parent_node = self._get_parent_node(node)
while parent_node is not None:
if parent_node == ancestor_node:
return True
parent_node = self._get_parent_node(parent_node)
return False
def _get_parent_node(self, node): def _get_parent_node(self, node):
wrapper_node = self.node_to_wrapper_map.get(node) wrapper_node = self.node_to_wrapper_map.get(node)
if wrapper_node: if wrapper_node:
parent_node = wrapper_node.parent.node if wrapper_node.parent:
return parent_node parent_node = wrapper_node.parent.node
return parent_node
return None return None
def _remove_unnecessary_vars(self, loop_vars, loop_node): def _remove_unnecessary_vars(self, loop_vars, loop_node):
...@@ -355,9 +365,22 @@ class NameVisitor(gast.NodeVisitor): ...@@ -355,9 +365,22 @@ class NameVisitor(gast.NodeVisitor):
if child_node.id in target_var_names: if child_node.id in target_var_names:
vars_of_list_generator.add(child_node) vars_of_list_generator.add(child_node)
# 2. Get target vars or vars from target vars used in for-loop. # 2. Get target vars or vars from target vars used in for-loop but the for-loop is
elif isinstance(parent_node, # 1) not the "loop_node" itself
gast.For) and parent_node is not loop_node: # 2) not the ancestor of the "loop_node"
#
# For examples:
# for k in range(x): # if it's this "loop_node", i or j both should be target vars.
# # do something
#
# for i in range(a): # if it's this "loop_node", k or j should be in target vars but i should not.
# for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not.
# x = i+j
elif isinstance(parent_node, gast.For):
if parent_node is loop_node:
continue
if self._is_ancestor_node(parent_node, loop_node):
continue
# 2.1 target vars in gast.For node. # 2.1 target vars in gast.For node.
target_node = parent_node.target target_node = parent_node.target
if isinstance(target_node, gast.Tuple): if isinstance(target_node, gast.Tuple):
......
...@@ -161,7 +161,7 @@ def nested_for_loop_dyfunc(): ...@@ -161,7 +161,7 @@ def nested_for_loop_dyfunc():
three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32") three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32")
for j in range(two): for j in range(two):
for i in range(10): for i in range(10):
a = 2 a = 2 + j
for i in range(three): for i in range(three):
b = fluid.layers.zeros(shape=[1], dtype='float32') b = fluid.layers.zeros(shape=[1], dtype='float32')
...@@ -216,16 +216,25 @@ class TestNameVisitor(unittest.TestCase): ...@@ -216,16 +216,25 @@ class TestNameVisitor(unittest.TestCase):
self.loop_var_names = [ self.loop_var_names = [
set(["j", "two"]), set(["j", "two"]),
set(["i", "three", "b"]), set(["i", "three", "b"]),
set(["i"]), set(["i", "j"]),
] ]
self.create_var_names = [set(), set(["b"]), set()] self.create_var_names = [set(), set(["b"]), set()]
i = 0 i = 0
for node in gast.walk(gast_root): for node in gast.walk(gast_root):
if isinstance(node, (gast.While, gast.For)): if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names( loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node) node)
self.assertEqual(loop_var_names, self.loop_var_names[i]) self.assertEqual(
self.assertEqual(create_var_names, self.create_var_names[i]) loop_var_names,
self.loop_var_names[i],
msg="loop_var_names : {}, \nexpected loop_var_names : {}".
format(loop_var_names, self.loop_var_names[i]))
self.assertEqual(
create_var_names,
self.create_var_names[i],
msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}".
format(i, create_var_names, self.create_var_names[i]))
i += 1 i += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册