未验证 提交 a5c18204 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Join break cond with while cond in some pattern (#28171)

* Join break cond with while cond

* remove usless code

* refine the if code

* Split into BreakTransfromOptimizer

* add BreakTransformOptimizer in ast_transformer

* add more comment
上级 7a3a05cc
...@@ -22,6 +22,7 @@ import gast ...@@ -22,6 +22,7 @@ import gast
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
...@@ -75,6 +76,7 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -75,6 +76,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
BasicApiTransformer, # Basic Api BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow ListTransformer, # List used in control flow
BreakTransformOptimizer, # optimize transfromation of break in loops
BreakContinueTransformer, # break/continue in loops BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not LogicalTransformer, # logical and/or/not
......
...@@ -19,6 +19,7 @@ import gast ...@@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer'] __all__ = ['BreakContinueTransformer']
...@@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer): ...@@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
return init_stmts return init_stmts
class BreakContinueTransformer(gast.NodeTransformer): class BreakContinueTransformer(BaseNodeVisitor):
""" """
Rewrite 'break' and 'continue' key words in a if-else python way to make Rewrite 'break' and 'continue' key words in a if-else python way to make
it equivalent to original control flow it equivalent to original control flow
...@@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer):
set continue to False at the beginning of each loop set continue to False at the beginning of each loop
TODO: more details should be summarized as design document TODO: more details should be summarized as design document
Note: The class is inherited from BaseNodeVisitor instead of NodeTransformer,
because ancestor nodes will be modified inplace for `Break/Continue` here.
In general, we recommend to inheriting NodeTransformer to modify node!
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
super(BreakContinueTransformer, self).__init__()
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.ancestor_nodes = []
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
def generic_visit(self, node):
# TODO: because we change ancestor nodes during visit_Break/Continue,
# not current node, so generic_visit of NodeTransformer will visit node
# which may be deleted. To prevent that node being added into
# transformed AST, I have to self-write a generic_visit, but this is
# NOT a good thing. Considering refactorying this whole class.
for field, value in gast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(value, gast.AST):
self.visit(value)
def visit(self, node):
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_Break(self, node): def visit_Break(self, node):
loop_node_index = self._find_ancestor_loop_index(node) loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop" assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index] loop_node = self.ancestor_nodes[loop_node_index]
...@@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self._remove_stmts_after_break_continue( first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index) node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one # 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive) # (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name) self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
...@@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
ctx=gast.Load(), ctx=gast.Load(),
annotation=None, annotation=None,
type_comment=None)) type_comment=None))
if isinstance(loop_node, gast.While): if isinstance(loop_node, gast.While):
loop_node.test = gast.BoolOp( loop_node.test = gast.BoolOp(
op=gast.And(), values=[loop_node.test, cond_var_node]) op=gast.And(), values=[loop_node.test, cond_var_node])
...@@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
for_to_while.transform() for_to_while.transform()
def visit_Continue(self, node): def visit_Continue(self, node):
loop_node_index = self._find_ancestor_loop_index(node) loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'continue' outside loop" assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
loop_node = self.ancestor_nodes[loop_node_index] loop_node = self.ancestor_nodes[loop_node_index]
...@@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self._remove_stmts_after_break_continue( first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index) node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one # 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive) # (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name) self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
...@@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer):
return first_block_index return first_block_index
def _replace_break_continue_in_stmt_list(
self, stmt_list, break_continue_node, break_continue_name):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
def _replace_if_stmt(self, loop_node_index, first_block_index, def _replace_if_stmt(self, loop_node_index, first_block_index,
break_continue_name): break_continue_name):
for i in range(first_block_index - 1, loop_node_index - 1, -1): for i in range(first_block_index - 1, loop_node_index - 1, -1):
...@@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer):
cur_node.orelse, son_node, break_continue_name): cur_node.orelse, son_node, break_continue_name):
continue continue
def _replace_break_continue_in_stmt_list(
self, stmt_list, break_continue_node, break_continue_name):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
break_continue_name): break_continue_name):
i = index_in_list(stmt_list, node) i = index_in_list(stmt_list, node)
...@@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer): ...@@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer):
stmt_list.insert(i, stmt_node) stmt_list.insert(i, stmt_node)
return True return True
def _find_ancestor_loop_index(self, node):
for i in range(len(self.ancestor_nodes) - 1, -1, -1): def _find_ancestor_loop_index(node, ancestor_nodes):
if isinstance(self.ancestor_nodes[i], (gast.For, gast.While)): for i in range(len(ancestor_nodes) - 1, -1, -1):
return i if isinstance(ancestor_nodes[i], (gast.For, gast.While)):
return -1 return i
return -1
class BreakTransformOptimizer(BaseNodeVisitor):
"""
In specific pattern, the transformed code could be optimized by joining the
If.test with while.test.
Currently supported pattern is:
```
while cond1: while cond1 and not cond2:
if cond2: ---> do_something()
break
do_something()
```
See following example:
>>> def foo(x):
... i = paddle.to_tensor(1, dtype='int32')
... while i < 10:
... if x.mean() > 5:
... break
... x += i
... i += 1
... return x
The generated code after applying optimization will be:
```
def foo(x):
i = paddle.to_tensor(1, dtype='int32')
while i < 10 and not x.mean() > 5:
x += i
i += 1
return x
```
It can avoid wrapping all ops after `break` statement into `cond_op` that
usually brings very heavy overhead.
"""
def __init__(self, wrapper_root):
super(BreakTransformOptimizer, self).__init__()
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
self.visit(self.root)
def visit_Break(self, node):
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
if self._is_break_cond_pattern(node, loop_node):
cond_var_node = self._join_with_while_cond(node, loop_node)
if isinstance(loop_node, gast.While):
loop_node.test = gast.BoolOp(
op=gast.And(), values=[loop_node.test, cond_var_node])
elif isinstance(loop_node, gast.For):
parent_node = self.ancestor_nodes[loop_node_index - 1]
for_to_while = ForToWhileTransformer(parent_node, loop_node,
cond_var_node)
for_to_while.transform()
def _is_break_cond_pattern(self, break_node, loop_node):
"""
Judge whether if match the pattern to join `If.test` with `while.test`
"""
# while/for -> if -> break
if len(self.ancestor_nodes) < 3 or self.ancestor_nodes[-3] != loop_node:
return False
assert self.ancestor_nodes[-1] == break_node
parent_if_node = self.ancestor_nodes[-2]
is_matched = False
if isinstance(parent_if_node, gast.If):
# gast.If only contains `break`
break_first_in_if = parent_if_node.body[0] == break_node and len(
parent_if_node.orelse) == 0
# gast.If is first node of loop_node
if_first_in_loop = loop_node.body[0] == parent_if_node
is_matched = if_first_in_loop and break_first_in_if
return is_matched
def _join_with_while_cond(self, break_node, loop_node):
"""
Join the `If.test` with `While.test` together.
"""
parent_if_node = self.ancestor_nodes[-2]
cond_var_node = gast.UnaryOp(op=gast.Not(), operand=parent_if_node.test)
# remove the gast.If node that contains the gast.Break.
assert loop_node.body[0] == parent_if_node
loop_node.body.pop(0)
return cond_var_node
...@@ -29,6 +29,28 @@ import numpy as np ...@@ -29,6 +29,28 @@ import numpy as np
from paddle.fluid import unique_name from paddle.fluid import unique_name
class BaseNodeVisitor(gast.NodeVisitor):
"""
Implement customized NodeVisitor inherited from gast.NodeVisitor.
Ancestor nodes are traced to easily support more operations of currently
visited node.
"""
def __init__(self):
self.ancestor_nodes = []
def visit(self, node):
"""Visit a node."""
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
# imp is deprecated in python3 # imp is deprecated in python3
if six.PY2: if six.PY2:
import imp import imp
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
...@@ -157,6 +158,30 @@ def while_loop_class_var(x): ...@@ -157,6 +158,30 @@ def while_loop_class_var(x):
return foo.c return foo.c
def test_optim_break_in_for(x):
x = paddle.to_tensor(x)
for i in range(10):
if x.sum() > 5:
break
x += 10086
x += i
if i < 3:
x = x * 2
return x
def test_optim_break_in_while(x):
x = paddle.to_tensor(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
if i > 5:
break
x += 10086
x += i
i += 1
return x
class TestContinueInFor(unittest.TestCase): class TestContinueInFor(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.zeros((1)).astype('int32') self.input = np.zeros((1)).astype('int32')
...@@ -226,5 +251,15 @@ class TestWhileLoopClassVar(TestContinueInWhile): ...@@ -226,5 +251,15 @@ class TestWhileLoopClassVar(TestContinueInWhile):
self.dygraph_func = while_loop_class_var self.dygraph_func = while_loop_class_var
class TestOptimBreakInFor(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_for
class TestOptimBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_while
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册