未验证 提交 eb1c0901 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat]Remove unnecessary vars from gast.comprehension in LoopTransformer. (#25094)

上级 a7944904
......@@ -117,15 +117,16 @@ class NameVisitor(gast.NodeVisitor):
var_node.ctx)
in_loop_vars = set(in_loop_vars_list)
in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node)
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node]
before_loop_body_vars = self._remove_target_vars_of_for(
before_loop_body_vars = self._remove_unnecessary_vars(
before_loop_body_vars, node)
before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node)
after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node)
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context)
condition_vars = self.condition_vars[node]
......@@ -138,7 +139,6 @@ class NameVisitor(gast.NodeVisitor):
for var in in_loop_vars:
wrapper = self.node_to_wrapper_map[var]
name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
for name in in_loop_name_strs:
if name in before_loop_name_strs:
# If a variable is used in loop and created before loop
......@@ -296,47 +296,83 @@ class NameVisitor(gast.NodeVisitor):
return parent_node
return None
def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node):
def _remove_unnecessary_vars(self, loop_vars, loop_node):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
2. Remove vars only in gast.comprehension.
:param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars = set()
for name_node in before_or_after_loop_vars:
vars_of_list_generator = set()
target_vars_of_for_node = set()
for name_node in loop_vars:
if not isinstance(name_node, gast.Name):
continue
parent_node = self._get_parent_node(name_node)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
# NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
# For examples:
# 1) `for i, j in enumerate(x)` has two target vars: i and j
# 2) `[x for x,y in array]` has two target vars: x and y
if isinstance(parent_node, gast.Tuple):
parent_node = self._get_parent_node(parent_node)
if isinstance(parent_node,
gast.For) and parent_node is not loop_node:
# 1. Get vars only in gast.comprehension.
# For examples:
# 1) [x for x,y in array] -> x, x, y
# 2) [f(x) for x in array] -> x
# 3) [func(x, y) for x in array] -> x, x
if isinstance(parent_node, gast.comprehension):
# 1.1 target vars in list/set comprehensions
target_node = parent_node.target
if isinstance(target_node, gast.Tuple):
target_vars = target_node.elts
else:
target_vars = [target_node]
if name_node in target_vars:
removed_vars.add(name_node)
vars_of_list_generator = vars_of_list_generator | set(
target_vars)
# 1.2 vars from target vars used in elt_node
target_var_names = {var.id for var in target_vars}
listcomp_node = self._get_parent_node(parent_node)
elt_node = listcomp_node.elt
if isinstance(elt_node, gast.Name):
if elt_node.id in target_var_names:
vars_of_list_generator.add(elt_node)
for child_node in gast.walk(elt_node):
if isinstance(child_node, gast.Name):
if child_node.id in target_var_names:
vars_of_list_generator.add(child_node)
# 2. Get target vars or vars from target vars used in for-loop.
elif isinstance(parent_node,
gast.For) and parent_node is not loop_node:
# 2.1 target vars in gast.For node.
target_node = parent_node.target
if isinstance(target_node, gast.Tuple):
target_vars = target_node.elts
else:
target_vars = [target_node]
removed_vars_name_strs = {var.id for var in removed_vars}
target_vars_of_for_node = target_vars_of_for_node | set(
target_vars)
for var in before_or_after_loop_vars:
# 2.2 vars from target vars used in for-loop
target_vars_name_strs = {var.id for var in target_vars_of_for_node}
for var in loop_vars:
if not isinstance(var, gast.Name):
continue
if var.id in removed_vars_name_strs and var not in self.condition_vars[
if var.id in target_vars_name_strs and var not in self.condition_vars[
loop_node]:
removed_vars.add(var)
target_vars_of_for_node.add(var)
return before_or_after_loop_vars - removed_vars
removed_vars = target_vars_of_for_node | vars_of_list_generator
return loop_vars - removed_vars
class LoopTransformer(gast.NodeTransformer):
......
......@@ -169,15 +169,28 @@ def nested_for_loop_dyfunc():
return b
def for_loop_dufunc_with_listcomp(array):
a = 1
for j in range(array):
res = [x + a for x in array]
res = [i for i in array]
x = 1
b = [i for i in array]
print(x)
return res
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [
while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none
while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none,
for_loop_dufunc_with_listcomp
]
self.loop_var_names = [
set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"])
set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"]),
set(["j", "array", "res", "x"])
]
self.create_var_names = [set(), set(["ret"]), set()]
self.create_var_names = [set(), set(["ret"]), set(), set(["res", "x"])]
self.nested_for_loop_func = nested_for_loop_dyfunc
......@@ -211,7 +224,6 @@ class TestNameVisitor(unittest.TestCase):
if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
# print(loop_var_names)
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
i += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册