From 2403362d06c4696c5611b0074f3b2a3aafcccaa0 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 18 Mar 2020 10:07:32 +0800 Subject: [PATCH] BugFix for parsing Arguments and inserting funcs in IfElseTransormer (#23035) * Support and/or in controlFlow if test=develop --- .../dygraph_to_static/ifelse_transformer.py | 181 +++++++++++++----- .../dygraph_to_static/ifelse_simple_func.py | 52 +++++ .../dygraph_to_static/test_ast_util.py | 2 +- .../{test_basic.py => test_ifelse.py} | 12 ++ .../dygraph_to_static/test_ifelse_basic.py | 18 +- 5 files changed, 206 insertions(+), 59 deletions(-) rename python/paddle/fluid/tests/unittests/dygraph_to_static/{test_basic.py => test_ifelse.py} (92%) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 345fcca52d2..82ee4ac123c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer): """ self._insert_func_nodes(node) - def _insert_func_nodes(self, parent_node): + def _insert_func_nodes(self, node): """ Defined `true_func` and `false_func` will be inserted in front of corresponding `layers.cond` statement instead of inserting them all into body of parent node. @@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer): For example, `self.var_dict["key"]`. In this case, nested structure of newly defined functions is easier to understand. """ - if not (self.new_func_nodes and hasattr(parent_node, 'body')): + if not self.new_func_nodes: return - idx = len(parent_node.body) - 1 + idx = -1 + if isinstance(node, list): + idx = len(node) - 1 + elif isinstance(node, gast.AST): + for _, child in gast.iter_fields(node): + self._insert_func_nodes(child) while idx >= 0: - child_node = parent_node.body[idx] + child_node = node[idx] if child_node in self.new_func_nodes: - parent_node.body[idx:idx] = self.new_func_nodes[child_node] + node[idx:idx] = self.new_func_nodes[child_node] idx = idx + len(self.new_func_nodes[child_node]) - 1 del self.new_func_nodes[child_node] else: @@ -366,51 +371,133 @@ class IfConditionVisitor(object): return new_node, new_assign_nodes -def get_name_ids(nodes, not_name_set=None, node_black_list=None): +class NameVisitor(gast.NodeVisitor): + def __init__(self, node_black_set=None): + # Set of nodes that will not be visited. + self.node_black_set = node_black_set or set() + # Dict to store the names and ctxs of vars. + self.name_ids = defaultdict(list) + # List of current visited nodes + self.ancestor_nodes = [] + # Available only when node_black_set is set. + self._is_finished = False + self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) + + def visit(self, node): + """Visit a node.""" + if node in self.node_black_set or self._is_finished: + self._is_finished = True + return + + self.ancestor_nodes.append(node) + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + ret = visitor(node) + self.ancestor_nodes.pop() + + return ret + + def visit_If(self, node): + """ + For nested `if/else`, the created vars are not always visible for parent node. + In addition, the vars created in `if.body` are not visible for `if.orelse`. + + Case 1: + x = 1 + if m > 1: + res = new_tensor + res = res + 1 # Error, `res` is not visible here. + + Case 2: + if x_tensor > 0: + res = new_tensor + else: + res = res + 1 # Error, `res` is not visible here. + + In above two cases, we should consider to manage the scope of vars to parsing + the arguments and returned vars correctly. + """ + before_if_name_ids = copy.deepcopy(self.name_ids) + body_name_ids = self._visit_child(node.body) + # If the traversal process stops early, just return the name_ids that have been seen. + if self._is_finished: + for name_id, ctxs in before_if_name_ids.items(): + self.name_ids[name_id] = ctxs + self.name_ids[name_id] + # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch + # into name_ids. + else: + else_name_ids = self._visit_child(node.orelse) + new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids) + for new_name_id in new_name_ids: + before_if_name_ids[new_name_id].append(gast.Store()) + + self.name_ids = before_if_name_ids + + def visit_Attribute(self, node): + if not self._is_call_func_name_node(node): + self.generic_visit(node) + + def visit_Name(self, node): + if not self._is_call_func_name_node(node): + if isinstance(node.ctx, self._candidate_ctxs): + self.name_ids[node.id].append(node.ctx) + + def visit_Assign(self, node): + # Visit `value` firstly. + node._fields = ('value', 'targets') + self.generic_visit(node) + + def visit_Return(self, node): + # Ignore the vars in return + return + + def _visit_child(self, node): + self.name_ids = defaultdict(list) + if isinstance(node, list): + for item in node: + if isinstance(item, gast.AST): + self.visit(item) + elif isinstance(node, gast.AST): + self.visit(node) + + return copy.deepcopy(self.name_ids) + + def _find_new_name_ids(self, body_name_ids, else_name_ids): + def is_required_ctx(ctxs, required_ctx): + for ctx in ctxs: + if isinstance(ctx, required_ctx): + return True + return False + + candidate_name_ids = set(body_name_ids.keys()) & set(else_name_ids.keys( + )) + store_ctx = gast.Store + new_name_ids = set() + for name_id in candidate_name_ids: + if is_required_ctx(body_name_ids[name_id], + store_ctx) and is_required_ctx( + else_name_ids[name_id], store_ctx): + new_name_ids.add(name_id) + + return new_name_ids + + def _is_call_func_name_node(self, node): + if len(self.ancestor_nodes) > 1: + assert self.ancestor_nodes[-1] == node + parent_node = self.ancestor_nodes[-2] + if isinstance(parent_node, gast.Call) and parent_node.func == node: + return True + return False + + +def get_name_ids(nodes, node_black_set=None): """ Return all ast.Name.id of python variable in nodes. """ - if not isinstance(nodes, (list, tuple, set)): - raise ValueError( - "nodes must be one of list, tuple, set, but received %s" % - type(nodes)) - if not_name_set is None: - not_name_set = set() - - def update(old_dict, new_dict): - for k, v in new_dict.items(): - old_dict[k].extend(v) - - name_ids = defaultdict(list) + name_visitor = NameVisitor(node_black_set) for node in nodes: - if node_black_list and node in node_black_list: - break - if isinstance(node, gast.AST): - # In two case, the ast.Name should be filtered. - # 1. Function name like `my_func` of my_func(x) - # 2. api prefix like `fluid` of `fluid.layers.mean` - if isinstance(node, gast.Return): - continue - elif isinstance(node, gast.Call) and isinstance(node.func, - gast.Name): - not_name_set.add(node.func.id) - elif isinstance(node, gast.Attribute) and isinstance(node.value, - gast.Name): - not_name_set.add(node.value.id) - if isinstance( - node, gast.Name - ) and node.id not in name_ids and node.id not in not_name_set: - if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)): - name_ids[node.id].append(node.ctx) - else: - if isinstance(node, gast.Assign): - node = copy.copy(node) - node._fields = ('value', 'targets') - for field, value in gast.iter_fields(node): - value = value if isinstance(value, list) else [value] - update(name_ids, - get_name_ids(value, not_name_set, node_black_list)) - return name_ids + name_visitor.visit(node) + return name_visitor.name_ids def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): @@ -508,7 +595,7 @@ def transform_if_else(node, root): """ Transform ast.If into control flow statement of Paddle static graph. """ - parent_name_ids = get_name_ids([root], node_black_list=[node]) + parent_name_ids = get_name_ids([root], node_black_set=[node]) if_name_ids = get_name_ids(node.body) else_name_ids = get_name_ids(node.orelse) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 277675a5876..8655bcab447 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -65,6 +65,58 @@ def nested_if_else(x_v): return y +def nested_if_else_2(x): + y = fluid.layers.reshape(x, [-1, 1]) + b = 2 + if b < 1: + # var `z` is not visible for outer scope + z = y + x_shape_0 = x.shape[0] + if x_shape_0 < 1: + if fluid.layers.shape(y).numpy()[0] < 1: + res = fluid.layers.fill_constant( + value=2, shape=x.shape, dtype="int32") + # `z` is a new var here. + z = y + 1 + else: + res = fluid.layers.fill_constant( + value=3, shape=x.shape, dtype="int32") + else: + res = x + return res + + +def nested_if_else_3(x): + y = fluid.layers.reshape(x, [-1, 1]) + b = 2 + # var `z` is visible for func.body + if b < 1: + z = y + else: + z = x + + if b < 1: + res = x + # var `out` is only visible for current `if` + if b > 1: + out = x + 1 + else: + out = x - 1 + else: + y_shape = fluid.layers.shape(y) + if y_shape.numpy()[0] < 1: + res = fluid.layers.fill_constant( + value=2, shape=x.shape, dtype="int32") + # `z` is created in above code block. + z = y + 1 + else: + res = fluid.layers.fill_constant( + value=3, shape=x.shape, dtype="int32") + # `out` is a new var. + out = x + 1 + return res + + class NetWithControlFlowIf(fluid.dygraph.Layer): def __init__(self, hidden_dim=16): super(NetWithControlFlowIf, self).__init__() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py index 6689f2e681d..a048e20799d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py @@ -22,7 +22,7 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func -from test_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else +from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else class TestAST2Func(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py similarity index 92% rename from python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic.py rename to python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index b3bc9100223..2502064b268 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -72,6 +72,18 @@ class TestDygraphIfElse3(TestDygraphIfElse): self.dyfunc = nested_if_else +class TestDygraphIfElse4(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = nested_if_else_2 + + +class TestDygraphIfElse5(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = nested_if_else_3 + + class TestDygraphIfElseWithAndOr(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py index 2ad95675c52..06e8d94a03b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -65,14 +65,10 @@ class TestGetNameIds2(TestGetNameIds): return z """ self.all_name_ids = { - 'x': [ - gast.Param(), gast.Store(), gast.Load(), gast.Load(), - gast.Load() - ], - 'a': [gast.Store(), gast.Load(), gast.Load()], - 'y': - [gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()], - 'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()] + 'x': [gast.Param(), gast.Store()], + 'a': [gast.Store(), gast.Load()], + 'y': [gast.Param(), gast.Load()], + 'z': [gast.Store()] } @@ -87,9 +83,9 @@ class TestGetNameIds3(TestGetNameIds): return z """ self.all_name_ids = { - 'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()], - 'y': [gast.Param(), gast.Load(), gast.Load()], - 'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()] + 'x': [gast.Param()], + 'y': [gast.Param()], + 'z': [gast.Store()] } -- GitLab