未验证 提交 e5f0e6b0 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Fix bug in loop_transformer: loop vars should contain the...

[Dynamic-to-Static] Fix bug in loop_transformer: loop vars should contain the var from ancestor-for-node (#28735)
上级 04cefeac
......@@ -294,9 +294,19 @@ class NameVisitor(gast.NodeVisitor):
return True
return False
def _is_ancestor_node(self, ancestor_node, node):
parent_node = self._get_parent_node(node)
while parent_node is not None:
if parent_node == ancestor_node:
return True
parent_node = self._get_parent_node(parent_node)
return False
def _get_parent_node(self, node):
wrapper_node = self.node_to_wrapper_map.get(node)
if wrapper_node:
if wrapper_node.parent:
parent_node = wrapper_node.parent.node
return parent_node
return None
......@@ -355,9 +365,22 @@ class NameVisitor(gast.NodeVisitor):
if child_node.id in target_var_names:
vars_of_list_generator.add(child_node)
# 2. Get target vars or vars from target vars used in for-loop.
elif isinstance(parent_node,
gast.For) and parent_node is not loop_node:
# 2. Get target vars or vars from target vars used in for-loop but the for-loop is
# 1) not the "loop_node" itself
# 2) not the ancestor of the "loop_node"
#
# For examples:
# for k in range(x): # if it's this "loop_node", i or j both should be target vars.
# # do something
#
# for i in range(a): # if it's this "loop_node", k or j should be in target vars but i should not.
# for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not.
# x = i+j
elif isinstance(parent_node, gast.For):
if parent_node is loop_node:
continue
if self._is_ancestor_node(parent_node, loop_node):
continue
# 2.1 target vars in gast.For node.
target_node = parent_node.target
if isinstance(target_node, gast.Tuple):
......
......@@ -161,7 +161,7 @@ def nested_for_loop_dyfunc():
three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32")
for j in range(two):
for i in range(10):
a = 2
a = 2 + j
for i in range(three):
b = fluid.layers.zeros(shape=[1], dtype='float32')
......@@ -216,16 +216,25 @@ class TestNameVisitor(unittest.TestCase):
self.loop_var_names = [
set(["j", "two"]),
set(["i", "three", "b"]),
set(["i"]),
set(["i", "j"]),
]
self.create_var_names = [set(), set(["b"]), set()]
i = 0
for node in gast.walk(gast_root):
if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
self.assertEqual(
loop_var_names,
self.loop_var_names[i],
msg="loop_var_names : {}, \nexpected loop_var_names : {}".
format(loop_var_names, self.loop_var_names[i]))
self.assertEqual(
create_var_names,
self.create_var_names[i],
msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}".
format(i, create_var_names, self.create_var_names[i]))
i += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册