未验证 提交 26bc953b 编写于 作者: A Aurelius84 提交者: GitHub

Fix returned arguments in IfElse_fn test=develop (#23102)

上级 0d8f40d2
......@@ -388,20 +388,20 @@ class IfConditionVisitor(object):
class NameVisitor(gast.NodeVisitor):
def __init__(self, node_black_set=None):
# Set of nodes that will not be visited.
self.node_black_set = node_black_set or set()
def __init__(self, end_node=None):
# The terminate node of the visitor.
self.end_node = end_node
# Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list)
# List of current visited nodes
self.ancestor_nodes = []
# Available only when node_black_set is set.
# Available only when end_node is set.
self._is_finished = False
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
def visit(self, node):
"""Visit a node."""
if node in self.node_black_set or self._is_finished:
if node == self.end_node or self._is_finished:
self._is_finished = True
return
......@@ -433,21 +433,28 @@ class NameVisitor(gast.NodeVisitor):
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body)
# If the traversal process stops early, just return the name_ids that have been seen.
if self._is_finished:
for name_id, ctxs in before_if_name_ids.items():
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
if not self.end_node:
self.generic_visit(node)
else:
else_name_ids = self._visit_child(node.orelse)
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
for new_name_id in new_name_ids:
before_if_name_ids[new_name_id].append(gast.Store())
self.name_ids = before_if_name_ids
before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body)
# If traversal process stops early in `if.body`, return the currently seen name_ids.
if self._is_finished:
self._update_name_ids(before_if_name_ids)
else:
else_name_ids = self._visit_child(node.orelse)
# If traversal process stops early in `if.orelse`, return the currently seen name_ids.
if self._is_finished:
self._update_name_ids(before_if_name_ids)
else:
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
new_name_ids = self._find_new_name_ids(body_name_ids,
else_name_ids)
for new_name_id in new_name_ids:
before_if_name_ids[new_name_id].append(gast.Store())
self.name_ids = before_if_name_ids
def visit_Attribute(self, node):
if not self._is_call_func_name_node(node):
......@@ -463,6 +470,19 @@ class NameVisitor(gast.NodeVisitor):
node._fields = ('value', 'targets')
self.generic_visit(node)
def visit_FunctionDef(self, node):
if not self.end_node:
self.generic_visit(node)
else:
before_name_ids = copy.deepcopy(self.name_ids)
self.name_ids = defaultdict(list)
self.generic_visit(node)
if self._is_finished:
self._update_name_ids(before_name_ids)
else:
self.name_ids = before_name_ids
def visit_Return(self, node):
# Ignore the vars in return
return
......@@ -505,12 +525,16 @@ class NameVisitor(gast.NodeVisitor):
return True
return False
def _update_name_ids(self, new_name_ids):
for name_id, ctxs in new_name_ids.items():
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
def get_name_ids(nodes, node_black_set=None):
def get_name_ids(nodes, end_node=None):
"""
Return all ast.Name.id of python variable in nodes.
"""
name_visitor = NameVisitor(node_black_set)
name_visitor = NameVisitor(end_node)
for node in nodes:
name_visitor.visit(node)
return name_visitor.name_ids
......@@ -611,7 +635,7 @@ def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids = get_name_ids([root], node_black_set=[node])
parent_name_ids = get_name_ids([root], end_node=node)
if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse)
......
......@@ -102,19 +102,18 @@ def test_break_continue_in_for(x):
def test_for_in_else(x):
x = fluid.dygraph.to_variable(x)
#
# TODO: Huihuang founds that if we put the for range in else body
# the testcase will fail. Enable this test case after fixing it.
#
#if False:
# pass
#else:
# for i in range(0, 10):
# if i > 5:
# x += 1
# break
# x += i
#
# Case 1:
if False:
pass
else:
for i in range(0, 10):
if i > 5:
x += 1
break
x += i
# Case 2:
if False:
pass
else:
......
......@@ -65,10 +65,24 @@ class TestGetNameIds2(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Store()],
'a': [gast.Store(), gast.Load()],
'y': [gast.Param(), gast.Load()],
'z': [gast.Store()]
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Load(),
gast.Store(),
gast.Store(),
]
}
......@@ -83,9 +97,23 @@ class TestGetNameIds3(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [gast.Param()],
'y': [gast.Param()],
'z': [gast.Store()]
'x': [
gast.Param(),
gast.Load(),
gast.Load(),
gast.Load(),
],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Store(),
gast.Load(),
gast.Store(),
]
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册