未验证 提交 8814853d 编写于 作者: H Huihuang Zheng 提交者: GitHub

Remove Read-Only Basic Type Variable in loop_vars (#23299)

* Remove Read-Only Basic Type Variable in loop_vars

test=develop

* Put class support in loop for future thing

test=develop

* Modify based on reviewer's comment

test=develop
上级 0471476a
...@@ -20,6 +20,7 @@ import gast ...@@ -20,6 +20,7 @@ import gast
from collections import defaultdict from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
...@@ -134,6 +135,12 @@ class NameVisitor(gast.NodeVisitor): ...@@ -134,6 +135,12 @@ class NameVisitor(gast.NodeVisitor):
self.before_loop_body_vars = defaultdict(set) self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set) self.in_loop_vars = defaultdict(set)
# Mapping from gast.While/gast.For to variable nodes which is condition
# of loop or being modified during the loop
self.write_in_loop = defaultdict(set)
self.condition_vars = defaultdict(set)
self.in_condition = False
self.static_analysis_visitor = StaticAnalysisVisitor(root_node) self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
) )
...@@ -158,14 +165,36 @@ class NameVisitor(gast.NodeVisitor): ...@@ -158,14 +165,36 @@ class NameVisitor(gast.NodeVisitor):
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars, after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context) read_context)
condition_vars = self.condition_vars[node]
condition_names = self._var_nodes_to_names(condition_vars)
write_vars = self.write_in_loop[node]
write_names = self._var_nodes_to_names(write_vars)
name_to_type = {}
for var in in_loop_vars:
wrapper = self.node_to_wrapper_map[var]
name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
for name in in_loop_name_strs: for name in in_loop_name_strs:
if name in before_loop_name_strs: if name in before_loop_name_strs:
# If a variable is used in loop and created before loop, it # If a variable is used in loop and created before loop
# should be in loop_var as input
# If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should
# be in loop_var as input
if (not name in condition_names) and (
not name in write_names
) and self._node_var_type_is_basic(name_to_type[name]):
continue
loop_var_names.add(name) loop_var_names.add(name)
elif name in after_loop_name_strs: elif name in after_loop_name_strs:
# If a variable is created in the while loop and read after # If a variable is created in the while loop and read after
# loop, it should be in loop_var and we should create it # loop, it should be in loop_var and we should create it
# because name in after_loop_name must be initialized in loop
# So it is write-only, we don't have to filter read-only basic
# vars out
loop_var_names.add(name) loop_var_names.add(name)
create_var_names.add(name) create_var_names.add(name)
return loop_var_names, create_var_names return loop_var_names, create_var_names
...@@ -179,8 +208,15 @@ class NameVisitor(gast.NodeVisitor): ...@@ -179,8 +208,15 @@ class NameVisitor(gast.NodeVisitor):
return return
self.current_seen_vars.add(node) self.current_seen_vars.add(node)
write_context = {
type(gast.Store()), type(gast.AugStore()), type(gast.Del())
}
for loop_node in self.current_loop: for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node) self.in_loop_vars[loop_node].add(node)
if type(node.ctx) in write_context:
self.write_in_loop[loop_node].add(node)
if self.in_condition:
self.condition_vars[loop_node].add(node)
self.generic_visit(node) self.generic_visit(node)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
...@@ -217,21 +253,28 @@ class NameVisitor(gast.NodeVisitor): ...@@ -217,21 +253,28 @@ class NameVisitor(gast.NodeVisitor):
if attr_full_name.startswith("self."): if attr_full_name.startswith("self."):
return return
self.current_seen_vars.add(node) self.current_seen_vars.add(node)
for loop_node in self.current_loop: for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node) self.in_loop_vars[loop_node].add(node)
# sub-nodes are visited during get_attribute_full_name and we shouldn't # sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again # visit again
def visit_For(self, node): def visit_For(self, node):
self.current_loop.append(node) self.current_loop.append(node)
self.in_condition = True
self.visit(node.target) self.visit(node.target)
self.visit(node.iter)
self.in_condition = False
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node) self.generic_visit(node)
self.current_loop.pop() self.current_loop.pop()
def visit_While(self, node): def visit_While(self, node):
self.current_loop.append(node) self.current_loop.append(node)
self.in_condition = True
self.visit(node.test) self.visit(node.test)
self.in_condition = False
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node) self.generic_visit(node)
self.current_loop.pop() self.current_loop.pop()
...@@ -240,12 +283,25 @@ class NameVisitor(gast.NodeVisitor): ...@@ -240,12 +283,25 @@ class NameVisitor(gast.NodeVisitor):
ret = set() ret = set()
for node in node_set: for node in node_set:
if ctx_filter_set is None or type(node.ctx) in ctx_filter_set: if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
if isinstance(node, gast.Name): ret.add(self._var_node_to_name(node))
ret.add(node.id)
elif isinstance(node, gast.Attribute):
ret.add(get_attribute_full_name(node))
return ret return ret
def _var_node_to_name(self, node):
if isinstance(node, gast.Name):
return node.id
elif isinstance(node, gast.Attribute):
return get_attribute_full_name(node)
def _node_var_type_is_basic(self, node_var_type):
basic_types = {
NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
NodeVarType.STRING
}
for t in node_var_type:
if t in basic_types:
return True
return False
def _is_call_func_name_node(self, node): def _is_call_func_name_node(self, node):
parent_node = self.node_to_wrapper_map[node].parent.node parent_node = self.node_to_wrapper_map[node].parent.node
if isinstance(parent_node, gast.Call) and parent_node.func == node: if isinstance(parent_node, gast.Call) and parent_node.func == node:
......
...@@ -90,8 +90,7 @@ class TestNameVisitor(unittest.TestCase): ...@@ -90,8 +90,7 @@ class TestNameVisitor(unittest.TestCase):
while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none
] ]
self.loop_var_names = [ self.loop_var_names = [
set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"])
set(["i", "x", "flag"])
] ]
self.create_var_names = [set(), set(["ret"]), set()] self.create_var_names = [set(), set(["ret"]), set()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册