未验证 提交 a0846b62 编写于 作者: L liym27 提交者: GitHub

Remove target vars of gast.For from before_loop_vars or after_loop_vars (#24732)

上级 d15fc95e
......@@ -166,13 +166,19 @@ class NameVisitor(gast.NodeVisitor):
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node]
before_loop_body_vars = self._remove_target_vars_of_for(
before_loop_body_vars, node)
before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node)
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context)
condition_vars = self.condition_vars[node]
condition_names = self._var_nodes_to_names(condition_vars)
write_vars = self.write_in_loop[node]
write_names = self._var_nodes_to_names(write_vars)
......@@ -203,6 +209,7 @@ class NameVisitor(gast.NodeVisitor):
# vars out
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names
def visit_Name(self, node):
......@@ -309,11 +316,60 @@ class NameVisitor(gast.NodeVisitor):
return False
def _is_call_func_name_node(self, node):
parent_node = self.node_to_wrapper_map[node].parent.node
parent_node = self._get_parent_node(node)
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
def _get_parent_node(self, node):
wrapper_node = self.node_to_wrapper_map.get(node)
if wrapper_node:
parent_node = wrapper_node.parent.node
return parent_node
return None
def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars = set()
for name_node in before_or_after_loop_vars:
if not isinstance(name_node, gast.Name):
continue
parent_node = self._get_parent_node(name_node)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
if isinstance(parent_node, gast.Tuple):
parent_node = self._get_parent_node(parent_node)
if isinstance(parent_node,
gast.For) and parent_node is not loop_node:
target_node = parent_node.target
if isinstance(target_node, gast.Tuple):
target_vars = target_node.elts
else:
target_vars = [target_node]
if name_node in target_vars:
removed_vars.add(name_node)
removed_vars_name_strs = {var.id for var in removed_vars}
for var in before_or_after_loop_vars:
if not isinstance(var, gast.Name):
continue
if var.id in removed_vars_name_strs and var not in self.condition_vars[
loop_node]:
removed_vars.add(var)
return before_or_after_loop_vars - removed_vars
class LoopTransformer(gast.NodeTransformer):
"""
......
......@@ -132,6 +132,19 @@ def var_create_in_for_loop(max_len):
return ret
def nested_for_loop_dyfunc():
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32")
for j in range(two):
for i in range(10):
a = 2
for i in range(three):
b = fluid.layers.zeros(shape=[1], dtype='float32')
return b
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [
......@@ -142,6 +155,8 @@ class TestNameVisitor(unittest.TestCase):
]
self.create_var_names = [set(), set(["ret"]), set()]
self.nested_for_loop_func = nested_for_loop_dyfunc
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
......@@ -155,6 +170,28 @@ class TestNameVisitor(unittest.TestCase):
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
def test_nested_loop_vars(self):
func = self.nested_for_loop_func
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
self.loop_var_names = [
set(["j", "two"]),
set(["i", "three", "b"]),
set(["i"]),
]
self.create_var_names = [set(), set(["b"]), set()]
i = 0
for node in gast.walk(gast_root):
if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
# print(loop_var_names)
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
i += 1
class TestTransformWhileLoop(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册