break_continue_transformer.py 14.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from paddle.utils import gast
16 17

from paddle.fluid import unique_name
18 19 20
from paddle.jit.dy2static.utils import index_in_list
from paddle.jit.dy2static.utils import BaseNodeVisitor
from paddle.jit.dy2static.variable_trans_func import (
21 22
    create_bool_node,
)
23
from .base_transformer import (
24 25
    BaseTransformer,
)
26
from .base_transformer import (
27 28
    ForNodeVisitor,
)
29 30 31 32 33 34 35

__all__ = ['BreakContinueTransformer']

BREAK_NAME_PREFIX = '__break'
CONTINUE_NAME_PREFIX = '__continue'


36
class ForToWhileTransformer(BaseTransformer):
    """
    Replace a ``gast.For`` statement inside its parent node by the statements
    of an equivalent ``while`` loop, and-ing an extra condition node into the
    loop test.
    """

    def __init__(self, parent_node, loop_node, condition_node):
        """
        Args:
            parent_node: node whose ``body`` or ``orelse`` statement list
                holds ``loop_node``.
            loop_node: the ``gast.For`` node to convert.
            condition_node: expression node and-ed into the new loop test.
        """
        assert isinstance(
            loop_node, gast.For
        ), "loop_node is not gast.For in ForToWhileTransformer"
        self.parent_node = parent_node
        self.loop_node = loop_node
        self.condition_node = condition_node

    def transform(self):
        """
        Splice the converted statements over the loop node in the parent.

        The parent's ``body`` is searched first, then its ``orelse``; the
        first list containing the loop node is modified in place.

        Returns:
            The list of statements that replaced the loop node.

        Raises:
            ValueError: if neither statement list contains the loop node.
        """
        for field_name in ('body', 'orelse'):
            if not hasattr(self.parent_node, field_name):
                continue
            stmts = getattr(self.parent_node, field_name)
            position = index_in_list(stmts, self.loop_node)
            if position == -1:
                continue
            replacement = self.get_for_stmt_nodes(stmts[position])
            stmts[position : position + 1] = replacement
            return replacement

        raise ValueError(
            "parent_node doesn't contain the loop_node in ForToWhileTransformer"
        )

    def get_for_stmt_nodes(self, node):
        """
        Build the statement list that stands in for the ``gast.For`` node.

        Returns ``[node]`` unchanged when ``ForNodeVisitor.parse()`` yields
        ``None``; otherwise the init statements followed by the new
        ``gast.While`` node (appended to the parsed init list itself).
        """
        assert isinstance(
            node, gast.For
        ), "Input node is NOT gast.For in get_for_stmt_nodes"

        # Split the for loop into (init statements, loop test, loop body).
        parsed = ForNodeVisitor(node).parse()
        if parsed is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = parsed

        # The loop keeps running only while both the original test and the
        # extra condition hold.
        combined_test = gast.BoolOp(
            op=gast.And(), values=[cond_stmt, self.condition_node]
        )

        init_stmts.append(
            gast.While(test=combined_test, body=body_stmts, orelse=node.orelse)
        )
        return init_stmts
94 95


96
class BreakContinueTransformer(BaseNodeVisitor):
    """
    Rewrite 'break' and 'continue' key words in a if-else python way to make
    it equivalent to original control flow

    The main idea of this class is:

        1. Map the 'break/continue' stmt with an unique boolean variable V.

        2. Find the first ancestor block containing this 'break/continue', a
        block can be a node containing stmt list. We should remove all stmts
        after the 'break/continue' and set the V to True here.

        3. Add 'if not V' for stmts in ancestor blocks between the first one
        (exclusive) and the ancestor loop (inclusive)

        4. For 'break', assign V = False right before the loop and add
        'not V' into the condition of the loop. For 'continue', set V to
        False at the beginning of each loop iteration.

        TODO: more details should be summarized as design document

    Note: The class is inherited from BaseNodeVisitor instead of NodeTransformer,
          because ancestor nodes will be modified inplace for `Break/Continue` here.
          In general, we recommend to inheriting NodeTransformer to modify node!
    """

    # Statement-list attributes probed on a node, always in this order.
    _STMT_LIST_FIELDS = ("body", "orelse")

    def __init__(self, wrapper_root):
        super().__init__()

        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        """Visit the wrapped root node, rewriting break/continue in place."""
        self.visit(self.root)

    def visit_Break(self, node):
        """Replace one ``break`` statement by a flag variable and guards."""
        loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt with an unique boolean variable V.
        variable_name = unique_name.generate(BREAK_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue', a
        # block can be a node containing stmt list. We should remove all stmts
        # after the 'break/continue' and set the V to True here.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index
        )

        # 3. Add 'if not V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'break' add break into condition of the loop. The flag is
        # initialised to False right before the loop statement.
        assign_false_node = create_bool_node(variable_name, False)
        self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

        cond_var_node = self._create_not_flag_node(variable_name)

        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, cond_var_node]
            )
        elif isinstance(loop_node, gast.For):
            # A for loop has no test expression, so it is converted into a
            # while loop whose test includes 'not V'.
            parent_node = self.ancestor_nodes[loop_node_index - 1]
            for_to_while = ForToWhileTransformer(
                parent_node, loop_node, cond_var_node
            )
            for_to_while.transform()

    def visit_Continue(self, node):
        """Replace one ``continue`` statement by a flag variable and guards."""
        loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt with an unique boolean variable V.
        variable_name = unique_name.generate(CONTINUE_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue', a
        # block can be a node containing stmt list. We should remove all stmts
        # after the 'break/continue' and set the V to True here.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index
        )

        # 3. Add 'if not V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'continue', set continue to False at the beginning of each loop
        assign_false_node = create_bool_node(variable_name, False)
        loop_node.body.insert(0, assign_false_node)

    @staticmethod
    def _create_not_flag_node(variable_name):
        """
        Build the expression node for ``not <variable_name>``.

        The flag is only read by this expression, so the Name node carries a
        ``Load`` context.
        """
        return gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=variable_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None,
            ),
        )

    def _remove_stmts_after_break_continue(
        self, break_continue_node, break_continue_name, loop_node_index
    ):
        """
        Walk the ancestors from the innermost one up to the loop node and, in
        the first statement list that directly holds the break/continue node,
        replace that node and everything after it by ``V = True``.

        Returns:
            Index (into ``self.ancestor_nodes``) of the ancestor whose
            statement list was rewritten. If no list held the node, the last
            index visited (``loop_node_index``) is returned.
        """
        for first_block_index in range(
            len(self.ancestor_nodes) - 1, loop_node_index - 1, -1
        ):
            first_block = self.ancestor_nodes[first_block_index]
            for field_name in self._STMT_LIST_FIELDS:
                if hasattr(
                    first_block, field_name
                ) and self._replace_break_continue_in_stmt_list(
                    getattr(first_block, field_name),
                    break_continue_node,
                    break_continue_name,
                ):
                    return first_block_index

        return first_block_index

    def _replace_if_stmt(
        self, loop_node_index, first_block_index, break_continue_name
    ):
        """
        For every ancestor between the rewritten block (exclusive) and the
        loop node (inclusive), wrap the statements that follow the child
        ancestor into ``if not V:``.
        """
        for i in range(first_block_index - 1, loop_node_index - 1, -1):
            cur_node = self.ancestor_nodes[i]
            son_node = self.ancestor_nodes[i + 1]
            for field_name in self._STMT_LIST_FIELDS:
                if hasattr(
                    cur_node, field_name
                ) and self._replace_after_node_to_if_in_stmt_list(
                    getattr(cur_node, field_name),
                    son_node,
                    break_continue_name,
                ):
                    # The child lives in this list; no need to probe the
                    # other one.
                    break

    def _replace_break_continue_in_stmt_list(
        self, stmt_list, break_continue_node, break_continue_name
    ):
        """
        Replace the break/continue node and all statements after it in
        ``stmt_list`` by a single ``V = True`` assignment.

        Returns:
            True if the node was found in ``stmt_list``, False otherwise.
        """
        i = index_in_list(stmt_list, break_continue_node)
        if i == -1:
            return False
        assign_true_node = create_bool_node(break_continue_name, True)
        stmt_list[i:] = [assign_true_node]
        return True

    def _replace_after_node_to_if_in_stmt_list(
        self, stmt_list, node, break_continue_name
    ):
        """
        Move every statement that follows ``node`` in ``stmt_list`` into the
        body of a new ``if not V:`` statement appended after ``node``.

        Returns:
            True if ``node`` was found in ``stmt_list`` (including the case
            where nothing follows it), False otherwise.
        """
        i = index_in_list(stmt_list, node)
        if i == -1:
            return False

        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        if_stmt = gast.If(
            # The flag is read here, hence Load (not Store) context.
            test=self._create_not_flag_node(break_continue_name),
            body=stmt_list[i + 1 :],
            orelse=[],
        )
        stmt_list[i + 1 :] = []
        stmt_list.append(if_stmt)
        return True

    def _add_stmt_before_cur_node(self, cur_node_index, stmt_node):
        """
        Insert ``stmt_node`` right before the ancestor at ``cur_node_index``
        in its parent's ``body`` or ``orelse`` list.

        Returns:
            True if the statement was inserted, False if the parent holds the
            current node in neither list.
        """
        cur_node = self.ancestor_nodes[cur_node_index]
        # NOTE(review): when cur_node_index is 0 this indexes the last
        # ancestor (-1); presumably a loop always has a parent - confirm.
        parent_node = self.ancestor_nodes[cur_node_index - 1]
        for field_name in self._STMT_LIST_FIELDS:
            if hasattr(
                parent_node, field_name
            ) and self._add_stmt_into_list_before_node(
                getattr(parent_node, field_name), cur_node, stmt_node
            ):
                return True
        return False

    def _add_stmt_into_list_before_node(self, stmt_list, node, stmt_node):
        """
        Insert ``stmt_node`` immediately before ``node`` in ``stmt_list``.

        Returns:
            True if ``node`` was found and the insertion done, else False.
        """
        i = index_in_list(stmt_list, node)
        if i == -1:
            return False
        stmt_list.insert(i, stmt_node)
        return True

302 303 304 305 306 307 308 309 310 311

def _find_ancestor_loop_index(node, ancestor_nodes):
    for i in range(len(ancestor_nodes) - 1, -1, -1):
        if isinstance(ancestor_nodes[i], (gast.For, gast.While)):
            return i
    return -1


class BreakTransformOptimizer(BaseNodeVisitor):
    """
    Fold a leading ``if cond: break`` of a loop into the loop condition.

    The only pattern handled is a loop whose first statement is an ``if``
    without ``else`` whose first statement is ``break``:
    ```
        while cond1:            while cond1 and not cond2:
            if cond2:    --->       do_something()
                break
            do_something()
    ```

    For instance:

    >>> def foo(x):
    ...     i = paddle.to_tensor(1, dtype='int32')
    ...     while i < 10:
    ...         if x.mean() > 5:
    ...             break
    ...         x += i
    ...         i += 1
    ...     return x

    is turned into:
    ```
        def foo(x):
            i = paddle.to_tensor(1, dtype='int32')
            while i < 10 and not x.mean() > 5:
                x += i
                i += 1
            return x
    ```
    It can avoid wrapping all ops after `break` statement into `cond_op` that
    usually brings very heavy overhead.
    """

    def __init__(self, wrapper_root):
        super().__init__()

        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        """Visit the wrapped root node, optimizing matching loops in place."""
        self.visit(self.root)

    def visit_Break(self, node):
        """Merge the guarding ``if`` test of ``node`` into its loop test."""
        loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        if not self._is_break_cond_pattern(node, loop_node):
            return

        negated_test = self._join_with_while_cond(node, loop_node)

        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, negated_test]
            )
        elif isinstance(loop_node, gast.For):
            # A for loop is converted to a while loop carrying the extra test.
            enclosing_node = self.ancestor_nodes[loop_node_index - 1]
            ForToWhileTransformer(
                enclosing_node, loop_node, negated_test
            ).transform()

    def _is_break_cond_pattern(self, break_node, loop_node):
        """
        Tell whether the ancestor chain ends with ``loop -> if -> break`` in
        the shape that allows joining `If.test` with `while.test`.
        """
        ancestors = self.ancestor_nodes
        # The chain must be exactly: while/for -> if -> break
        if len(ancestors) < 3 or ancestors[-3] != loop_node:
            return False

        assert ancestors[-1] == break_node
        if_node = ancestors[-2]
        if not isinstance(if_node, gast.If):
            return False

        # `break` opens the if body and there is no else branch.
        break_leads_if = (
            if_node.body[0] == break_node and len(if_node.orelse) == 0
        )
        # The if statement opens the loop body.
        if_leads_loop = loop_node.body[0] == if_node

        return if_leads_loop and break_leads_if

    def _join_with_while_cond(self, break_node, loop_node):
        """
        Drop the leading ``if`` from the loop body and return the node for
        ``not <If.test>`` to be and-ed into the loop test.
        """
        if_node = self.ancestor_nodes[-2]
        negated_test = gast.UnaryOp(op=gast.Not(), operand=if_node.test)

        # The gast.If holding the gast.Break must be the first loop statement;
        # remove it from the loop body.
        assert loop_node.body[0] == if_node
        del loop_node.body[0]

        return negated_test