未验证 提交 8814853d 编写于 作者: H Huihuang Zheng 提交者: GitHub

Remove Read-Only Basic Type Variable in loop_vars (#23299)

* Remove Read-Only Basic Type Variable in loop_vars

test=develop

* Put class support in loop for future thing

test=develop

* Modify based on reviewer's comment

test=develop
上级 0471476a
......@@ -20,6 +20,7 @@ import gast
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
......@@ -134,6 +135,12 @@ class NameVisitor(gast.NodeVisitor):
self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
# Mapping from gast.While/gast.For to variable nodes which is condition
# of loop or being modified during the loop
self.write_in_loop = defaultdict(set)
self.condition_vars = defaultdict(set)
self.in_condition = False
self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
......@@ -158,14 +165,36 @@ class NameVisitor(gast.NodeVisitor):
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context)
condition_vars = self.condition_vars[node]
condition_names = self._var_nodes_to_names(condition_vars)
write_vars = self.write_in_loop[node]
write_names = self._var_nodes_to_names(write_vars)
name_to_type = {}
for var in in_loop_vars:
wrapper = self.node_to_wrapper_map[var]
name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
for name in in_loop_name_strs:
if name in before_loop_name_strs:
# If a variable is used in loop and created before loop, it
# should be in loop_var as input
# If a variable is used in loop and created before loop
# If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should
# be in loop_var as input
if (not name in condition_names) and (
not name in write_names
) and self._node_var_type_is_basic(name_to_type[name]):
continue
loop_var_names.add(name)
elif name in after_loop_name_strs:
# If a variable is created in the while loop and read after
# loop, it should be in loop_var and we should create it
# because name in after_loop_name must be initialized in loop
# So it is write-only, we don't have to filter read-only basic
# vars out
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names
......@@ -179,8 +208,15 @@ class NameVisitor(gast.NodeVisitor):
return
self.current_seen_vars.add(node)
write_context = {
type(gast.Store()), type(gast.AugStore()), type(gast.Del())
}
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
if type(node.ctx) in write_context:
self.write_in_loop[loop_node].add(node)
if self.in_condition:
self.condition_vars[loop_node].add(node)
self.generic_visit(node)
def visit_FunctionDef(self, node):
......@@ -217,21 +253,28 @@ class NameVisitor(gast.NodeVisitor):
if attr_full_name.startswith("self."):
return
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
# sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again
def visit_For(self, node):
self.current_loop.append(node)
self.in_condition = True
self.visit(node.target)
self.visit(node.iter)
self.in_condition = False
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def visit_While(self, node):
self.current_loop.append(node)
self.in_condition = True
self.visit(node.test)
self.in_condition = False
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
......@@ -240,11 +283,24 @@ class NameVisitor(gast.NodeVisitor):
ret = set()
for node in node_set:
if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
ret.add(self._var_node_to_name(node))
return ret
def _var_node_to_name(self, node):
if isinstance(node, gast.Name):
ret.add(node.id)
return node.id
elif isinstance(node, gast.Attribute):
ret.add(get_attribute_full_name(node))
return ret
return get_attribute_full_name(node)
def _node_var_type_is_basic(self, node_var_type):
basic_types = {
NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
NodeVarType.STRING
}
for t in node_var_type:
if t in basic_types:
return True
return False
def _is_call_func_name_node(self, node):
parent_node = self.node_to_wrapper_map[node].parent.node
......
......@@ -90,8 +90,7 @@ class TestNameVisitor(unittest.TestCase):
while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none
]
self.loop_var_names = [
set(["i", "x"]), set(["i", "ret", "max_len"]),
set(["i", "x", "flag"])
set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"])
]
self.create_var_names = [set(), set(["ret"]), set()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册