diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 09ac73057af20d31c012379be39f5755fba22603..ef5caa2dd87e9296cb49f86ce574e30d639aa7ec 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -125,6 +125,11 @@ class NameVisitor(gast.NodeVisitor): # List of gast.While/gast.For nodes self.current_loop = [] + # List of nodes that have scope of variables. + self.nodes_with_scope = [] + + self.blacklist_names = {"False", "True", "None"} + # Mapping from gast.While/gast.For to variable nodes self.before_loop_body_vars = defaultdict(set) self.in_loop_vars = defaultdict(set) @@ -169,7 +174,7 @@ class NameVisitor(gast.NodeVisitor): if self._is_call_func_name_node(node): self.generic_visit(node) return - if node.id == "False" or node.id == "True" or node.id == "None": + if node.id in self.blacklist_names: self.generic_visit(node) return @@ -178,6 +183,19 @@ class NameVisitor(gast.NodeVisitor): self.in_loop_vars[loop_node].add(node) self.generic_visit(node) + def visit_FunctionDef(self, node): + self.nodes_with_scope.append(node) + self.blacklist_names.add(node.name) + # The variables in the function are not visible to the outside scope. + before_func_seen_vars = copy.copy(self.current_seen_vars) + + self.generic_visit(node) + self.nodes_with_scope.pop() + # After exiting the scope of the node, variables in this scope + # should be removed from self.current_seen_vars. + if self.nodes_with_scope: + self.current_seen_vars = before_func_seen_vars + def visit(self, node): method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) @@ -188,6 +206,16 @@ class NameVisitor(gast.NodeVisitor): if self._is_call_func_name_node(node): return attr_full_name = get_attribute_full_name(node) + # Class variables are not allowed to appear in the arguments list + # of defined function under class methods in Python. + """ + def class_func(self): + def while_loop_body(self.x, y) # `self.x` is illegal. + """ + # TODO: If do change the variable with `self.var`, need a better + # way to deal with this case. + if attr_full_name.startswith("self."): + return self.current_seen_vars.add(node) for loop_node in self.current_loop: self.in_loop_vars[loop_node].add(node) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 7c500a05c2512f8cbeb38cbd9a87d0271e59222a..baf396a41b7373f40baa6e2aa745500d3f8d93fc 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -35,6 +35,23 @@ def while_loop_dyfunc(x): return i +def while_loop_dyfun_with_conflict_var(x): + i = fluid.dygraph.to_variable(x) + + def relu(y): + # 'y' is not visible outside the scope. + return fluid.layers.relu(y) + + while x < 10: + # If a tmp variable is created which has same name + # with a argument in function, it should not be + # included in the loop_vars. + add_fn = lambda x, y: x + y + i = add_fn(i, x) + x = x + 1 + return i + + def while_loop_dyfunc_with_none(x): i = fluid.dygraph.to_variable(x)\ if x is not None \ @@ -131,6 +148,11 @@ class TestTransformWhileLoop(unittest.TestCase): # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_dyfun_with_conflict_var + + class TestTransformWhileLoopWithNone(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_dyfunc_with_none