diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 2c59a66f22be2aa646a60da14e85173b9c8d81a7..fa168a62de11a9bebb2199924576e32685ed6513 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -22,6 +22,7 @@ import gast from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer +from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer @@ -75,6 +76,7 @@ class DygraphToStaticAst(gast.NodeTransformer): BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) ListTransformer, # List used in control flow + BreakTransformOptimizer, # optimize transformation of break in loops BreakContinueTransformer, # break/continue in loops ReturnTransformer, # return in functions LogicalTransformer, # logical and/or/not diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index c78f6e8f403196fc098914c4cc58c8a16a4d885c..cb0383b9f736235b85735e38635f3db43ab23784 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -19,6 +19,7 @@ import gast from paddle.fluid import unique_name from
paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node __all__ = ['BreakContinueTransformer'] @@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer): return init_stmts -class BreakContinueTransformer(gast.NodeTransformer): +class BreakContinueTransformer(BaseNodeVisitor): """ Rewrite 'break' and 'continue' key words in a if-else python way to make it equivalent to original control flow @@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer): set continue to False at the beginning of each loop TODO: more details should be summarized as design document + + Note: The class is inherited from BaseNodeVisitor instead of NodeTransformer, + because ancestor nodes will be modified in place for `Break/Continue` here. + In general, we recommend inheriting from NodeTransformer to modify nodes. """ def __init__(self, wrapper_root): + super(BreakContinueTransformer, self).__init__() + self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.ancestor_nodes = [] - def transform(self): self.visit(self.root) - def generic_visit(self, node): - # TODO: because we change ancestor nodes during visit_Break/Continue, - # not current node, so generic_visit of NodeTransformer will visit node - # which may be deleted. To prevent that node being added into - # transformed AST, I have to self-write a generic_visit, but this is - # NOT a good thing. Considering refactorying this whole class.
- for field, value in gast.iter_fields(node): - if isinstance(value, list): - for item in value: - if isinstance(item, gast.AST): - self.visit(item) - elif isinstance(value, gast.AST): - self.visit(value) - - def visit(self, node): - self.ancestor_nodes.append(node) - method = 'visit_' + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - ret = visitor(node) - self.ancestor_nodes.pop() - return ret - def visit_Break(self, node): - loop_node_index = self._find_ancestor_loop_index(node) + loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes) assert loop_node_index != -1, "SyntaxError: 'break' outside loop" loop_node = self.ancestor_nodes[loop_node_index] @@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer): first_block_index = self._remove_stmts_after_break_continue( node, variable_name, loop_node_index) - # 3. Add 'if V' for stmts in ancestor blocks between the first one + # 3. Add 'if not V' for stmts in ancestor blocks between the first one # (exclusive) and the ancestor loop (inclusive) self._replace_if_stmt(loop_node_index, first_block_index, variable_name) @@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer): ctx=gast.Load(), annotation=None, type_comment=None)) + if isinstance(loop_node, gast.While): loop_node.test = gast.BoolOp( op=gast.And(), values=[loop_node.test, cond_var_node]) @@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer): for_to_while.transform() def visit_Continue(self, node): - loop_node_index = self._find_ancestor_loop_index(node) + loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes) assert loop_node_index != -1, "SyntaxError: 'continue' outside loop" loop_node = self.ancestor_nodes[loop_node_index] @@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer): first_block_index = self._remove_stmts_after_break_continue( node, variable_name, loop_node_index) - # 3. 
Add 'if V' for stmts in ancestor blocks between the first one + # 3. Add 'if not V' for stmts in ancestor blocks between the first one # (exclusive) and the ancestor loop (inclusive) self._replace_if_stmt(loop_node_index, first_block_index, variable_name) @@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer): return first_block_index - def _replace_break_continue_in_stmt_list( - self, stmt_list, break_continue_node, break_continue_name): - i = index_in_list(stmt_list, break_continue_node) - if i == -1: - return False - assign_true_node = create_fill_constant_node(break_continue_name, True) - stmt_list[i:] = [assign_true_node] - return True - def _replace_if_stmt(self, loop_node_index, first_block_index, break_continue_name): for i in range(first_block_index - 1, loop_node_index - 1, -1): @@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer): cur_node.orelse, son_node, break_continue_name): continue + def _replace_break_continue_in_stmt_list( + self, stmt_list, break_continue_node, break_continue_name): + i = index_in_list(stmt_list, break_continue_node) + if i == -1: + return False + assign_true_node = create_fill_constant_node(break_continue_name, True) + stmt_list[i:] = [assign_true_node] + return True + def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, break_continue_name): i = index_in_list(stmt_list, node) @@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer): stmt_list.insert(i, stmt_node) return True - def _find_ancestor_loop_index(self, node): - for i in range(len(self.ancestor_nodes) - 1, -1, -1): - if isinstance(self.ancestor_nodes[i], (gast.For, gast.While)): - return i - return -1 + +def _find_ancestor_loop_index(node, ancestor_nodes): + for i in range(len(ancestor_nodes) - 1, -1, -1): + if isinstance(ancestor_nodes[i], (gast.For, gast.While)): + return i + return -1 + + +class BreakTransformOptimizer(BaseNodeVisitor): + """ + In a specific pattern, the transformed code could be
optimized by joining the + If.test with while.test. + + The currently supported pattern is: + ``` + while cond1: while cond1 and not cond2: + if cond2: ---> do_something() + break + do_something() + ``` + + See the following example: + + >>> def foo(x): + ... i = paddle.to_tensor(1, dtype='int32') + ... while i < 10: + ... if x.mean() > 5: + ... break + ... x += i + ... i += 1 + ... return x + + The generated code after applying optimization will be: + ``` + def foo(x): + i = paddle.to_tensor(1, dtype='int32') + while i < 10 and not x.mean() > 5: + x += i + i += 1 + return x + ``` + It can avoid wrapping all ops after `break` statement into `cond_op` that + usually brings very heavy overhead. + """ + + def __init__(self, wrapper_root): + super(BreakTransformOptimizer, self).__init__() + + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + + def visit_Break(self, node): + loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes) + assert loop_node_index != -1, "SyntaxError: 'break' outside loop" + loop_node = self.ancestor_nodes[loop_node_index] + + if self._is_break_cond_pattern(node, loop_node): + cond_var_node = self._join_with_while_cond(node, loop_node) + + if isinstance(loop_node, gast.While): + loop_node.test = gast.BoolOp( + op=gast.And(), values=[loop_node.test, cond_var_node]) + elif isinstance(loop_node, gast.For): + parent_node = self.ancestor_nodes[loop_node_index - 1] + for_to_while = ForToWhileTransformer(parent_node, loop_node, + cond_var_node) + for_to_while.transform() + + def _is_break_cond_pattern(self, break_node, loop_node): + """ + Judge whether the code matches the pattern for joining `If.test` with `while.test` + """ + # while/for -> if -> break + if len(self.ancestor_nodes) < 3 or self.ancestor_nodes[-3] != loop_node: + return False + + assert self.ancestor_nodes[-1] == break_node + parent_if_node = self.ancestor_nodes[-2] + + is_matched = False + if isinstance(parent_if_node,
gast.If): + # gast.If only contains `break` + break_first_in_if = parent_if_node.body[0] == break_node and len( + parent_if_node.orelse) == 0 + # gast.If is first node of loop_node + if_first_in_loop = loop_node.body[0] == parent_if_node + + is_matched = if_first_in_loop and break_first_in_if + + return is_matched + + def _join_with_while_cond(self, break_node, loop_node): + """ + Join the `If.test` with `While.test` together. + """ + parent_if_node = self.ancestor_nodes[-2] + + cond_var_node = gast.UnaryOp(op=gast.Not(), operand=parent_if_node.test) + + # remove the gast.If node that contains the gast.Break. + assert loop_node.body[0] == parent_if_node + loop_node.body.pop(0) + + return cond_var_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 7a234580712ac15c946100ed60fc8e2d8849bd16..b44739ca8484ba4a73aa09ba82eaab4ea5681ed4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -29,6 +29,28 @@ import numpy as np from paddle.fluid import unique_name + +class BaseNodeVisitor(gast.NodeVisitor): + """ + Implement customized NodeVisitor inherited from gast.NodeVisitor. + Ancestor nodes are traced to easily support more operations of currently + visited node. 
+ """ + + def __init__(self): + self.ancestor_nodes = [] + + def visit(self, node): + """Visit a node.""" + self.ancestor_nodes.append(node) + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + ret = visitor(node) + self.ancestor_nodes.pop() + return ret + + # imp is deprecated in python3 if six.PY2: import imp diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 6bcbc2b4a0babe97f4d299581f184072b8a286d0..8423c056b2d83038aba7c2671af8d30fa874a36b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.jit import declarative @@ -157,6 +158,30 @@ def while_loop_class_var(x): return foo.c +def test_optim_break_in_for(x): + x = paddle.to_tensor(x) + for i in range(10): + if x.sum() > 5: + break + x += 10086 + x += i + if i < 3: + x = x * 2 + return x + + +def test_optim_break_in_while(x): + x = paddle.to_tensor(x) + i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) + while i < 10: + if i > 5: + break + x += 10086 + x += i + i += 1 + return x + + class TestContinueInFor(unittest.TestCase): def setUp(self): self.input = np.zeros((1)).astype('int32') @@ -226,5 +251,15 @@ class TestWhileLoopClassVar(TestContinueInWhile): self.dygraph_func = while_loop_class_var +class TestOptimBreakInFor(TestContinueInWhile): + def init_dygraph_func(self): + self.dygraph_func = test_optim_break_in_for + + +class TestOptimBreakInWhile(TestContinueInWhile): + def init_dygraph_func(self): + self.dygraph_func = test_optim_break_in_while + + if __name__ == '__main__': unittest.main()