diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index ef5caa2dd87e9296cb49f86ce574e30d639aa7ec..a638c66af039214d1650247867d4e445ab116edc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -20,6 +20,7 @@ import gast from collections import defaultdict from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node @@ -134,6 +135,12 @@ class NameVisitor(gast.NodeVisitor): self.before_loop_body_vars = defaultdict(set) self.in_loop_vars = defaultdict(set) + # Mapping from gast.While/gast.For to variable nodes which is condition + # of loop or being modified during the loop + self.write_in_loop = defaultdict(set) + self.condition_vars = defaultdict(set) + self.in_condition = False + self.static_analysis_visitor = StaticAnalysisVisitor(root_node) self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( ) @@ -158,14 +165,36 @@ class NameVisitor(gast.NodeVisitor): after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars after_loop_name_strs = self._var_nodes_to_names(after_loop_vars, read_context) + condition_vars = self.condition_vars[node] + condition_names = self._var_nodes_to_names(condition_vars) + write_vars = self.write_in_loop[node] + write_names = self._var_nodes_to_names(write_vars) + + name_to_type = {} + for var in in_loop_vars: + wrapper = self.node_to_wrapper_map[var] + name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type + for name in in_loop_name_strs: if name in before_loop_name_strs: - # If a variable is used in loop and created before loop, it - # should be in loop_var as input + # If a variable is used in loop and created before loop + + # If this var is a basic variable and read-only and not + # condition var, it may not be loop_var else it should + # be in loop_var as input + if (not name in condition_names) and ( + not name in write_names + ) and self._node_var_type_is_basic(name_to_type[name]): + continue loop_var_names.add(name) + elif name in after_loop_name_strs: # If a variable is created in the while loop and read after # loop, it should be in loop_var and we should create it + + # because name in after_loop_name must be initialized in loop + # So it is write-only, we don't have to filter read-only basic + # vars out loop_var_names.add(name) create_var_names.add(name) return loop_var_names, create_var_names @@ -179,8 +208,15 @@ class NameVisitor(gast.NodeVisitor): return self.current_seen_vars.add(node) + write_context = { + type(gast.Store()), type(gast.AugStore()), type(gast.Del()) + } for loop_node in self.current_loop: self.in_loop_vars[loop_node].add(node) + if type(node.ctx) in write_context: + self.write_in_loop[loop_node].add(node) + if self.in_condition: + self.condition_vars[loop_node].add(node) self.generic_visit(node) def visit_FunctionDef(self, node): @@ -217,21 +253,28 @@ class NameVisitor(gast.NodeVisitor): if attr_full_name.startswith("self."): return self.current_seen_vars.add(node) + for loop_node in self.current_loop: self.in_loop_vars[loop_node].add(node) + # sub-nodes are visited during get_attribute_full_name and we shouldn't # visit again def visit_For(self, node): self.current_loop.append(node) + self.in_condition = True self.visit(node.target) + self.visit(node.iter) + self.in_condition = False self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() def visit_While(self, node): self.current_loop.append(node) + self.in_condition = True self.visit(node.test) + self.in_condition = False self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() @@ -240,12 +283,25 @@ class NameVisitor(gast.NodeVisitor): ret = set() for node in node_set: if ctx_filter_set is None or type(node.ctx) in ctx_filter_set: - if isinstance(node, gast.Name): - ret.add(node.id) - elif isinstance(node, gast.Attribute): - ret.add(get_attribute_full_name(node)) + ret.add(self._var_node_to_name(node)) return ret + def _var_node_to_name(self, node): + if isinstance(node, gast.Name): + return node.id + elif isinstance(node, gast.Attribute): + return get_attribute_full_name(node) + + def _node_var_type_is_basic(self, node_var_type): + basic_types = { + NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT, + NodeVarType.STRING + } + for t in node_var_type: + if t in basic_types: + return True + return False + def _is_call_func_name_node(self, node): parent_node = self.node_to_wrapper_map[node].parent.node if isinstance(parent_node, gast.Call) and parent_node.func == node: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index baf396a41b7373f40baa6e2aa745500d3f8d93fc..9b673bdcd1b958f75508c0457d1874cca6c4c52a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -90,8 +90,7 @@ class TestNameVisitor(unittest.TestCase): while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none ] self.loop_var_names = [ - set(["i", "x"]), set(["i", "ret", "max_len"]), - set(["i", "x", "flag"]) + set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"]) ] self.create_var_names = [set(), set(["ret"]), set()]