未验证 提交 a5c18204 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Join break cond with while cond in some pattern (#28171)

* Join break cond with while cond

* remove usless code

* refine the if code

* Split into BreakTransfromOptimizer

* add BreakTransformOptimizer in ast_transformer

* add more comment
上级 7a3a05cc
......@@ -22,6 +22,7 @@ import gast
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
......@@ -75,6 +76,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow
BreakTransformOptimizer, # optimize transfromation of break in loops
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
......
......@@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer']
......@@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
return init_stmts
class BreakContinueTransformer(gast.NodeTransformer):
class BreakContinueTransformer(BaseNodeVisitor):
"""
Rewrite 'break' and 'continue' key words in a if-else python way to make
it equivalent to original control flow
......@@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer):
set continue to False at the beginning of each loop
TODO: more details should be summarized as design document
Note: The class is inherited from BaseNodeVisitor instead of NodeTransformer,
because ancestor nodes will be modified inplace for `Break/Continue` here.
In general, we recommend to inheriting NodeTransformer to modify node!
"""
def __init__(self, wrapper_root):
super(BreakContinueTransformer, self).__init__()
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.ancestor_nodes = []
def transform(self):
self.visit(self.root)
def generic_visit(self, node):
# TODO: because we change ancestor nodes during visit_Break/Continue,
# not current node, so generic_visit of NodeTransformer will visit node
# which may be deleted. To prevent that node being added into
# transformed AST, I have to self-write a generic_visit, but this is
# NOT a good thing. Considering refactorying this whole class.
for field, value in gast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(value, gast.AST):
self.visit(value)
def visit(self, node):
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_Break(self, node):
loop_node_index = self._find_ancestor_loop_index(node)
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
......@@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
......@@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
ctx=gast.Load(),
annotation=None,
type_comment=None))
if isinstance(loop_node, gast.While):
loop_node.test = gast.BoolOp(
op=gast.And(), values=[loop_node.test, cond_var_node])
......@@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
for_to_while.transform()
def visit_Continue(self, node):
loop_node_index = self._find_ancestor_loop_index(node)
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
......@@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
......@@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer):
return first_block_index
def _replace_break_continue_in_stmt_list(
self, stmt_list, break_continue_node, break_continue_name):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
def _replace_if_stmt(self, loop_node_index, first_block_index,
break_continue_name):
for i in range(first_block_index - 1, loop_node_index - 1, -1):
......@@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer):
cur_node.orelse, son_node, break_continue_name):
continue
def _replace_break_continue_in_stmt_list(
self, stmt_list, break_continue_node, break_continue_name):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
break_continue_name):
i = index_in_list(stmt_list, node)
......@@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer):
stmt_list.insert(i, stmt_node)
return True
def _find_ancestor_loop_index(self, node):
for i in range(len(self.ancestor_nodes) - 1, -1, -1):
if isinstance(self.ancestor_nodes[i], (gast.For, gast.While)):
def _find_ancestor_loop_index(node, ancestor_nodes):
for i in range(len(ancestor_nodes) - 1, -1, -1):
if isinstance(ancestor_nodes[i], (gast.For, gast.While)):
return i
return -1
class BreakTransformOptimizer(BaseNodeVisitor):
"""
In specific pattern, the transformed code could be optimized by joining the
If.test with while.test.
Currently supported pattern is:
```
while cond1: while cond1 and not cond2:
if cond2: ---> do_something()
break
do_something()
```
See following example:
>>> def foo(x):
... i = paddle.to_tensor(1, dtype='int32')
... while i < 10:
... if x.mean() > 5:
... break
... x += i
... i += 1
... return x
The generated code after applying optimization will be:
```
def foo(x):
i = paddle.to_tensor(1, dtype='int32')
while i < 10 and not x.mean() > 5:
x += i
i += 1
return x
```
It can avoid wrapping all ops after `break` statement into `cond_op` that
usually brings very heavy overhead.
"""
def __init__(self, wrapper_root):
super(BreakTransformOptimizer, self).__init__()
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
self.visit(self.root)
def visit_Break(self, node):
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
if self._is_break_cond_pattern(node, loop_node):
cond_var_node = self._join_with_while_cond(node, loop_node)
if isinstance(loop_node, gast.While):
loop_node.test = gast.BoolOp(
op=gast.And(), values=[loop_node.test, cond_var_node])
elif isinstance(loop_node, gast.For):
parent_node = self.ancestor_nodes[loop_node_index - 1]
for_to_while = ForToWhileTransformer(parent_node, loop_node,
cond_var_node)
for_to_while.transform()
def _is_break_cond_pattern(self, break_node, loop_node):
"""
Judge whether if match the pattern to join `If.test` with `while.test`
"""
# while/for -> if -> break
if len(self.ancestor_nodes) < 3 or self.ancestor_nodes[-3] != loop_node:
return False
assert self.ancestor_nodes[-1] == break_node
parent_if_node = self.ancestor_nodes[-2]
is_matched = False
if isinstance(parent_if_node, gast.If):
# gast.If only contains `break`
break_first_in_if = parent_if_node.body[0] == break_node and len(
parent_if_node.orelse) == 0
# gast.If is first node of loop_node
if_first_in_loop = loop_node.body[0] == parent_if_node
is_matched = if_first_in_loop and break_first_in_if
return is_matched
def _join_with_while_cond(self, break_node, loop_node):
"""
Join the `If.test` with `While.test` together.
"""
parent_if_node = self.ancestor_nodes[-2]
cond_var_node = gast.UnaryOp(op=gast.Not(), operand=parent_if_node.test)
# remove the gast.If node that contains the gast.Break.
assert loop_node.body[0] == parent_if_node
loop_node.body.pop(0)
return cond_var_node
......@@ -29,6 +29,28 @@ import numpy as np
from paddle.fluid import unique_name
class BaseNodeVisitor(gast.NodeVisitor):
"""
Implement customized NodeVisitor inherited from gast.NodeVisitor.
Ancestor nodes are traced to easily support more operations of currently
visited node.
"""
def __init__(self):
self.ancestor_nodes = []
def visit(self, node):
"""Visit a node."""
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
# imp is deprecated in python3
if six.PY2:
import imp
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
......@@ -157,6 +158,30 @@ def while_loop_class_var(x):
return foo.c
def test_optim_break_in_for(x):
x = paddle.to_tensor(x)
for i in range(10):
if x.sum() > 5:
break
x += 10086
x += i
if i < 3:
x = x * 2
return x
def test_optim_break_in_while(x):
x = paddle.to_tensor(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
if i > 5:
break
x += 10086
x += i
i += 1
return x
class TestContinueInFor(unittest.TestCase):
def setUp(self):
self.input = np.zeros((1)).astype('int32')
......@@ -226,5 +251,15 @@ class TestWhileLoopClassVar(TestContinueInWhile):
self.dygraph_func = while_loop_class_var
class TestOptimBreakInFor(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_for
class TestOptimBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_while
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册