未验证 提交 26bc953b 编写于 作者: A Aurelius84 提交者: GitHub

Fix returned arguments in IfElse_fn test=develop (#23102)

上级 0d8f40d2
...@@ -388,20 +388,20 @@ class IfConditionVisitor(object): ...@@ -388,20 +388,20 @@ class IfConditionVisitor(object):
class NameVisitor(gast.NodeVisitor): class NameVisitor(gast.NodeVisitor):
def __init__(self, node_black_set=None): def __init__(self, end_node=None):
# Set of nodes that will not be visited. # The terminate node of the visitor.
self.node_black_set = node_black_set or set() self.end_node = end_node
# Dict to store the names and ctxs of vars. # Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list) self.name_ids = defaultdict(list)
# List of current visited nodes # List of current visited nodes
self.ancestor_nodes = [] self.ancestor_nodes = []
# Available only when node_black_set is set. # Available only when end_node is set.
self._is_finished = False self._is_finished = False
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
if node in self.node_black_set or self._is_finished: if node == self.end_node or self._is_finished:
self._is_finished = True self._is_finished = True
return return
...@@ -433,17 +433,24 @@ class NameVisitor(gast.NodeVisitor): ...@@ -433,17 +433,24 @@ class NameVisitor(gast.NodeVisitor):
In above two cases, we should consider to manage the scope of vars to parsing In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly. the arguments and returned vars correctly.
""" """
if not self.end_node:
self.generic_visit(node)
else:
before_if_name_ids = copy.deepcopy(self.name_ids) before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body) body_name_ids = self._visit_child(node.body)
# If the traversal process stops early, just return the name_ids that have been seen. # If traversal process stops early in `if.body`, return the currently seen name_ids.
if self._is_finished: if self._is_finished:
for name_id, ctxs in before_if_name_ids.items(): self._update_name_ids(before_if_name_ids)
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
else: else:
else_name_ids = self._visit_child(node.orelse) else_name_ids = self._visit_child(node.orelse)
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids) # If traversal process stops early in `if.orelse`, return the currently seen name_ids.
if self._is_finished:
self._update_name_ids(before_if_name_ids)
else:
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
new_name_ids = self._find_new_name_ids(body_name_ids,
else_name_ids)
for new_name_id in new_name_ids: for new_name_id in new_name_ids:
before_if_name_ids[new_name_id].append(gast.Store()) before_if_name_ids[new_name_id].append(gast.Store())
...@@ -463,6 +470,19 @@ class NameVisitor(gast.NodeVisitor): ...@@ -463,6 +470,19 @@ class NameVisitor(gast.NodeVisitor):
node._fields = ('value', 'targets') node._fields = ('value', 'targets')
self.generic_visit(node) self.generic_visit(node)
def visit_FunctionDef(self, node):
if not self.end_node:
self.generic_visit(node)
else:
before_name_ids = copy.deepcopy(self.name_ids)
self.name_ids = defaultdict(list)
self.generic_visit(node)
if self._is_finished:
self._update_name_ids(before_name_ids)
else:
self.name_ids = before_name_ids
def visit_Return(self, node): def visit_Return(self, node):
# Ignore the vars in return # Ignore the vars in return
return return
...@@ -505,12 +525,16 @@ class NameVisitor(gast.NodeVisitor): ...@@ -505,12 +525,16 @@ class NameVisitor(gast.NodeVisitor):
return True return True
return False return False
def _update_name_ids(self, new_name_ids):
for name_id, ctxs in new_name_ids.items():
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
def get_name_ids(nodes, node_black_set=None): def get_name_ids(nodes, end_node=None):
""" """
Return all ast.Name.id of python variable in nodes. Return all ast.Name.id of python variable in nodes.
""" """
name_visitor = NameVisitor(node_black_set) name_visitor = NameVisitor(end_node)
for node in nodes: for node in nodes:
name_visitor.visit(node) name_visitor.visit(node)
return name_visitor.name_ids return name_visitor.name_ids
...@@ -611,7 +635,7 @@ def transform_if_else(node, root): ...@@ -611,7 +635,7 @@ def transform_if_else(node, root):
""" """
Transform ast.If into control flow statement of Paddle static graph. Transform ast.If into control flow statement of Paddle static graph.
""" """
parent_name_ids = get_name_ids([root], node_black_set=[node]) parent_name_ids = get_name_ids([root], end_node=node)
if_name_ids = get_name_ids(node.body) if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse) else_name_ids = get_name_ids(node.orelse)
......
...@@ -102,19 +102,18 @@ def test_break_continue_in_for(x): ...@@ -102,19 +102,18 @@ def test_break_continue_in_for(x):
def test_for_in_else(x): def test_for_in_else(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
#
# TODO: Huihuang founds that if we put the for range in else body # Case 1:
# the testcase will fail. Enable this test case after fixing it. if False:
# pass
#if False: else:
# pass for i in range(0, 10):
#else: if i > 5:
# for i in range(0, 10): x += 1
# if i > 5: break
# x += 1 x += i
# break
# x += i # Case 2:
#
if False: if False:
pass pass
else: else:
......
...@@ -65,10 +65,24 @@ class TestGetNameIds2(TestGetNameIds): ...@@ -65,10 +65,24 @@ class TestGetNameIds2(TestGetNameIds):
return z return z
""" """
self.all_name_ids = { self.all_name_ids = {
'x': [gast.Param(), gast.Store()], 'x': [
'a': [gast.Store(), gast.Load()], gast.Param(), gast.Store(), gast.Load(), gast.Load(),
'y': [gast.Param(), gast.Load()], gast.Load()
'z': [gast.Store()] ],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Load(),
gast.Store(),
gast.Store(),
]
} }
...@@ -83,9 +97,23 @@ class TestGetNameIds3(TestGetNameIds): ...@@ -83,9 +97,23 @@ class TestGetNameIds3(TestGetNameIds):
return z return z
""" """
self.all_name_ids = { self.all_name_ids = {
'x': [gast.Param()], 'x': [
'y': [gast.Param()], gast.Param(),
'z': [gast.Store()] gast.Load(),
gast.Load(),
gast.Load(),
],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Store(),
gast.Load(),
gast.Store(),
]
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册