From 26bc953b3f123427c9959de2bb6e369184a0dcde Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 19 Mar 2020 18:40:28 +0800 Subject: [PATCH] Fix returned arguments in IfElse_fn test=develop (#23102) --- .../dygraph_to_static/ifelse_transformer.py | 68 +++++++++++++------ .../dygraph_to_static/test_break_continue.py | 25 ++++--- .../dygraph_to_static/test_ifelse_basic.py | 42 ++++++++++-- 3 files changed, 93 insertions(+), 42 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index aa3edb0a789..cce8067805b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -388,20 +388,20 @@ class IfConditionVisitor(object): class NameVisitor(gast.NodeVisitor): - def __init__(self, node_black_set=None): - # Set of nodes that will not be visited. - self.node_black_set = node_black_set or set() + def __init__(self, end_node=None): + # The terminate node of the visitor. + self.end_node = end_node # Dict to store the names and ctxs of vars. self.name_ids = defaultdict(list) # List of current visited nodes self.ancestor_nodes = [] - # Available only when node_black_set is set. + # Available only when end_node is set. self._is_finished = False self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) def visit(self, node): """Visit a node.""" - if node in self.node_black_set or self._is_finished: + if node == self.end_node or self._is_finished: self._is_finished = True return @@ -433,21 +433,28 @@ class NameVisitor(gast.NodeVisitor): In above two cases, we should consider to manage the scope of vars to parsing the arguments and returned vars correctly. """ - before_if_name_ids = copy.deepcopy(self.name_ids) - body_name_ids = self._visit_child(node.body) - # If the traversal process stops early, just return the name_ids that have been seen. - if self._is_finished: - for name_id, ctxs in before_if_name_ids.items(): - self.name_ids[name_id] = ctxs + self.name_ids[name_id] - # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch - # into name_ids. + if not self.end_node: + self.generic_visit(node) else: - else_name_ids = self._visit_child(node.orelse) - new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids) - for new_name_id in new_name_ids: - before_if_name_ids[new_name_id].append(gast.Store()) - - self.name_ids = before_if_name_ids + before_if_name_ids = copy.deepcopy(self.name_ids) + body_name_ids = self._visit_child(node.body) + # If traversal process stops early in `if.body`, return the currently seen name_ids. + if self._is_finished: + self._update_name_ids(before_if_name_ids) + else: + else_name_ids = self._visit_child(node.orelse) + # If traversal process stops early in `if.orelse`, return the currently seen name_ids. + if self._is_finished: + self._update_name_ids(before_if_name_ids) + else: + # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch + # into name_ids. + new_name_ids = self._find_new_name_ids(body_name_ids, + else_name_ids) + for new_name_id in new_name_ids: + before_if_name_ids[new_name_id].append(gast.Store()) + + self.name_ids = before_if_name_ids def visit_Attribute(self, node): if not self._is_call_func_name_node(node): @@ -463,6 +470,19 @@ class NameVisitor(gast.NodeVisitor): node._fields = ('value', 'targets') self.generic_visit(node) + def visit_FunctionDef(self, node): + if not self.end_node: + self.generic_visit(node) + else: + before_name_ids = copy.deepcopy(self.name_ids) + self.name_ids = defaultdict(list) + self.generic_visit(node) + + if self._is_finished: + self._update_name_ids(before_name_ids) + else: + self.name_ids = before_name_ids + def visit_Return(self, node): # Ignore the vars in return return @@ -505,12 +525,16 @@ class NameVisitor(gast.NodeVisitor): return True return False + def _update_name_ids(self, new_name_ids): + for name_id, ctxs in new_name_ids.items(): + self.name_ids[name_id] = ctxs + self.name_ids[name_id] + -def get_name_ids(nodes, node_black_set=None): +def get_name_ids(nodes, end_node=None): """ Return all ast.Name.id of python variable in nodes. """ - name_visitor = NameVisitor(node_black_set) + name_visitor = NameVisitor(end_node) for node in nodes: name_visitor.visit(node) return name_visitor.name_ids @@ -611,7 +635,7 @@ def transform_if_else(node, root): """ Transform ast.If into control flow statement of Paddle static graph. """ - parent_name_ids = get_name_ids([root], node_black_set=[node]) + parent_name_ids = get_name_ids([root], end_node=node) if_name_ids = get_name_ids(node.body) else_name_ids = get_name_ids(node.orelse) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 31870146149..bc13a2f121d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -102,19 +102,18 @@ def test_break_continue_in_for(x): def test_for_in_else(x): x = fluid.dygraph.to_variable(x) - # - # TODO: Huihuang founds that if we put the for range in else body - # the testcase will fail. Enable this test case after fixing it. - # - #if False: - # pass - #else: - # for i in range(0, 10): - # if i > 5: - # x += 1 - # break - # x += i - # + + # Case 1: + if False: + pass + else: + for i in range(0, 10): + if i > 5: + x += 1 + break + x += i + + # Case 2: if False: pass else: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py index 06e8d94a03b..6cc0634996a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -65,10 +65,24 @@ class TestGetNameIds2(TestGetNameIds): return z """ self.all_name_ids = { - 'x': [gast.Param(), gast.Store()], - 'a': [gast.Store(), gast.Load()], - 'y': [gast.Param(), gast.Load()], - 'z': [gast.Store()] + 'x': [ + gast.Param(), gast.Store(), gast.Load(), gast.Load(), + gast.Load() + ], + 'a': [gast.Store(), gast.Load(), gast.Load()], + 'y': [ + gast.Param(), + gast.Load(), + gast.Load(), + gast.Load(), + gast.Load(), + ], + 'z': [ + gast.Store(), + gast.Load(), + gast.Store(), + gast.Store(), + ] } @@ -83,9 +97,23 @@ class TestGetNameIds3(TestGetNameIds): return z """ self.all_name_ids = { - 'x': [gast.Param()], - 'y': [gast.Param()], - 'z': [gast.Store()] + 'x': [ + gast.Param(), + gast.Load(), + gast.Load(), + gast.Load(), + ], + 'y': [ + gast.Param(), + gast.Load(), + gast.Load(), + ], + 'z': [ + gast.Store(), + gast.Store(), + gast.Load(), + gast.Store(), + ] } -- GitLab