未验证 提交 f8205ffa 编写于 作者: A Aurelius84 提交者: GitHub

fix conflict var bug in loop_transformer test=develop (#23287)

上级 16e74f11
......@@ -125,6 +125,11 @@ class NameVisitor(gast.NodeVisitor):
# List of gast.While/gast.For nodes
self.current_loop = []
# List of nodes that have scope of variables.
self.nodes_with_scope = []
self.blacklist_names = {"False", "True", "None"}
# Mapping from gast.While/gast.For to variable nodes
self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
......@@ -169,7 +174,7 @@ class NameVisitor(gast.NodeVisitor):
if self._is_call_func_name_node(node):
self.generic_visit(node)
return
if node.id == "False" or node.id == "True" or node.id == "None":
if node.id in self.blacklist_names:
self.generic_visit(node)
return
......@@ -178,6 +183,19 @@ class NameVisitor(gast.NodeVisitor):
self.in_loop_vars[loop_node].add(node)
self.generic_visit(node)
def visit_FunctionDef(self, node):
self.nodes_with_scope.append(node)
self.blacklist_names.add(node.name)
# The variables in the function are not visible to the outside scope.
before_func_seen_vars = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.nodes_with_scope.pop()
# After exiting the scope of the node, variables in this scope
# should be removed from self.current_seen_vars.
if self.nodes_with_scope:
self.current_seen_vars = before_func_seen_vars
def visit(self, node):
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
......@@ -188,6 +206,16 @@ class NameVisitor(gast.NodeVisitor):
if self._is_call_func_name_node(node):
return
attr_full_name = get_attribute_full_name(node)
# Class variables are not allowed to appear in the arguments list
# of defined function under class methods in Python.
"""
def class_func(self):
def while_loop_body(self.x, y) # `self.x` is illegal.
"""
# TODO: If do change the variable with `self.var`, need a better
# way to deal with this case.
if attr_full_name.startswith("self."):
return
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
......
......@@ -35,6 +35,23 @@ def while_loop_dyfunc(x):
return i
def while_loop_dyfun_with_conflict_var(x):
i = fluid.dygraph.to_variable(x)
def relu(y):
# 'y' is not visible outside the scope.
return fluid.layers.relu(y)
while x < 10:
# If a tmp variable is created which has same name
# with a argument in function, it should not be
# included in the loop_vars.
add_fn = lambda x, y: x + y
i = add_fn(i, x)
x = x + 1
return i
def while_loop_dyfunc_with_none(x):
i = fluid.dygraph.to_variable(x)\
if x is not None \
......@@ -131,6 +148,11 @@ class TestTransformWhileLoop(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfun_with_conflict_var
class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfunc_with_none
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册