break_continue_transformer.py 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17
from paddle.fluid import unique_name
from paddle.jit.dy2static.utils import BaseNodeVisitor, index_in_list
from paddle.jit.dy2static.variable_trans_func import create_bool_node
18
from paddle.utils import gast
19

20
from .base_transformer import BaseTransformer, ForNodeVisitor
21

22
# Nothing in this module is exported through `from <module> import *`.
__all__ = []
23 24 25 26 27

# Prefix handed to `unique_name.generate` when naming the boolean flag that
# stands in for a `break` statement.
BREAK_NAME_PREFIX = '__break'
# Prefix handed to `unique_name.generate` when naming the boolean flag that
# stands in for a `continue` statement.
CONTINUE_NAME_PREFIX = '__continue'


28
class ForToWhileTransformer(BaseTransformer):
    """
    Transform python for loop into while loop and add condition node in the
    loop test.

    The rewrite happens in place: the ``gast.For`` statement is looked up in
    the ``body`` (and then the ``orelse``) statement list of ``parent_node``
    and replaced there by the statements produced by
    ``get_for_stmt_nodes``.
    """

    def __init__(self, parent_node, loop_node, condition_node):
        """
        Args:
            parent_node: node whose ``body`` or ``orelse`` list contains
                ``loop_node``.
            loop_node: the ``gast.For`` node to rewrite.
            condition_node: expression node that is AND-ed onto the loop
                condition of the generated ``gast.While``.
        """
        assert isinstance(
            loop_node, gast.For
        ), "loop_node is not gast.For in ForToWhileTransformer"
        self.parent_node = parent_node
        self.loop_node = loop_node
        self.condition_node = condition_node

    def transform(self):
        """
        Replace ``loop_node`` inside ``parent_node`` with its while-loop form.

        Returns:
            The list of statements that now stands where the for loop was.

        Raises:
            ValueError: if ``loop_node`` is found in neither
                ``parent_node.body`` nor ``parent_node.orelse``.
        """
        # `body` is searched before `orelse`; the first list that contains
        # the loop is the one that gets patched.
        for field_name in ('body', 'orelse'):
            if not hasattr(self.parent_node, field_name):
                continue
            stmt_list = getattr(self.parent_node, field_name)
            i = index_in_list(stmt_list, self.loop_node)
            if i == -1:
                continue
            new_stmts = self.get_for_stmt_nodes(stmt_list[i])
            # Splice in place so that the parent keeps the same list object.
            stmt_list[i : i + 1] = new_stmts
            return new_stmts
        raise ValueError(
            "parent_node doesn't contain the loop_node in ForToWhileTransformer"
        )

    def get_for_stmt_nodes(self, node):
        """
        Build the statements that replace the ``gast.For`` node ``node``.

        Returns:
            ``[node]`` unchanged when ``ForNodeVisitor(node).parse()`` returns
            ``None``; otherwise the init statements returned by the parser
            with the new ``gast.While`` node appended.
        """
        assert isinstance(
            node, gast.For
        ), "Input node is NOT gast.For in get_for_stmt_nodes"

        # 1. parse current gast.For node
        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. append break statement
        new_cond_stmt = gast.BoolOp(
            op=gast.And(), values=[cond_stmt, self.condition_node]
        )

        # 3. construct gast.While node
        while_node = gast.While(
            test=new_cond_stmt, body=body_stmts, orelse=node.orelse
        )
        init_stmts.append(while_node)
        return init_stmts
86 87


88
class BreakContinueTransformer(BaseNodeVisitor):
    """
    Rewrite 'break' and 'continue' key words in a if-else python way to make
    it equivalent to original control flow

    The main idea of this class is:

        1. Map the 'break/continue' stmt with an unique boolean variable V.

        2. Find the first ancestor block containing this 'break/continue', a
        block can be a node containing stmt list. We should remove all stmts
        after the 'break/continue' and set the V to True here.

        3. Add 'if not V' for stmts in ancestor blocks between the first one
        (exclusive) and the ancestor loop (inclusive)

        4. For 'break' add 'not V' into condition of the loop. For 'continue',
        set V to False at the beginning of each loop

        TODO: more details should be summarized as design document

    Note: The class is inherited from BaseNodeVisitor instead of NodeTransformer,
          because ancestor nodes will be modified inplace for `Break/Continue` here.
          In general, we recommend to inheriting NodeTransformer to modify node!
    """

    def __init__(self, root):
        super().__init__()

        self.root = root

    def transform(self):
        """Visit the tree rooted at ``self.root``, rewriting it in place."""
        self.visit(self.root)

    def visit_Break(self, node):
        """Replace the ``break`` statement ``node`` with flag-based control flow."""
        loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt with an unique boolean variable V.
        variable_name = unique_name.generate(BREAK_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue', a
        # block can be a node containing stmt list. We should remove all stmts
        # after the 'break/continue' and set the V to True here.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index
        )

        # 3. Add 'if not V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'break' add break into condition of the loop.
        # V is initialised to False right before the loop statement.
        assign_false_node = create_bool_node(variable_name, False)
        self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

        cond_var_node = self._create_not_flag_node(variable_name)

        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, cond_var_node]
            )
        elif isinstance(loop_node, gast.For):
            # A for loop has no test expression to extend, so it is handed
            # to ForToWhileTransformer together with the extra condition.
            parent_node = self.ancestor_nodes[loop_node_index - 1]
            for_to_while = ForToWhileTransformer(
                parent_node, loop_node, cond_var_node
            )
            for_to_while.transform()

    def visit_Continue(self, node):
        """Replace the ``continue`` statement ``node`` with flag-based control flow."""
        loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt with an unique boolean variable V.
        variable_name = unique_name.generate(CONTINUE_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue', a
        # block can be a node containing stmt list. We should remove all stmts
        # after the 'break/continue' and set the V to True here.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index
        )

        # 3. Add 'if not V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'continue', set continue to False at the beginning of each loop
        assign_false_node = create_bool_node(variable_name, False)
        loop_node.body.insert(0, assign_false_node)

    @staticmethod
    def _create_not_flag_node(flag_name):
        """
        Build the expression ``not <flag_name>``.

        The flag is only read by this expression, so the name node carries a
        ``gast.Load()`` context.
        """
        return gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=flag_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None,
            ),
        )

    def _remove_stmts_after_break_continue(
        self, break_continue_node, break_continue_name, loop_node_index
    ):
        """
        Replace the break/continue statement, and everything after it in the
        same statement list, with an assignment of True to the flag.

        Ancestors are searched from the innermost one outwards up to the loop
        itself; for each, ``body`` is tried before ``orelse``.

        Returns:
            Index in ``self.ancestor_nodes`` of the ancestor whose statement
            list held the break/continue. Falls back to ``loop_node_index``
            when no list held it.
        """
        for first_block_index in range(
            len(self.ancestor_nodes) - 1, loop_node_index - 1, -1
        ):
            first_block = self.ancestor_nodes[first_block_index]
            if hasattr(
                first_block, "body"
            ) and self._replace_break_continue_in_stmt_list(
                first_block.body, break_continue_node, break_continue_name
            ):
                return first_block_index

            if hasattr(
                first_block, "orelse"
            ) and self._replace_break_continue_in_stmt_list(
                first_block.orelse, break_continue_node, break_continue_name
            ):
                return first_block_index

        # The search ends at the loop node, so this is the index the loop
        # variable held on exhaustion; returning it explicitly avoids relying
        # on the leaked loop variable.
        return loop_node_index

    def _replace_if_stmt(
        self, loop_node_index, first_block_index, break_continue_name
    ):
        """
        Guard the trailing statements of every ancestor between the first
        block (exclusive) and the loop (inclusive) with ``if not V``.
        """
        for i in range(first_block_index - 1, loop_node_index - 1, -1):
            cur_node = self.ancestor_nodes[i]
            son_node = self.ancestor_nodes[i + 1]
            # `son_node` lives in exactly one of the two lists; `orelse` is
            # only tried when `body` did not contain it.
            if hasattr(
                cur_node, 'body'
            ) and self._replace_after_node_to_if_in_stmt_list(
                cur_node.body, son_node, break_continue_name
            ):
                continue
            if hasattr(cur_node, 'orelse'):
                self._replace_after_node_to_if_in_stmt_list(
                    cur_node.orelse, son_node, break_continue_name
                )

    def _replace_break_continue_in_stmt_list(
        self, stmt_list, break_continue_node, break_continue_name
    ):
        """
        Replace ``break_continue_node`` and all statements after it in
        ``stmt_list`` by a single assignment node created for the flag with
        value True.

        Returns:
            True if the node was found in ``stmt_list`` (and replaced),
            False otherwise.
        """
        i = index_in_list(stmt_list, break_continue_node)
        if i == -1:
            return False
        assign_true_node = create_bool_node(break_continue_name, True)
        stmt_list[i:] = [assign_true_node]
        return True

    def _replace_after_node_to_if_in_stmt_list(
        self, stmt_list, node, break_continue_name
    ):
        """
        Move every statement that follows ``node`` in ``stmt_list`` into the
        body of a new ``if not <flag>:`` statement appended after ``node``.

        Returns:
            True if ``node`` was found in ``stmt_list`` (also when nothing
            follows it and no ``if`` is needed), False otherwise.
        """
        i = index_in_list(stmt_list, node)
        if i == -1:
            return False

        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        if_stmt = gast.If(
            # The flag is read here, hence a Load context (built by the same
            # helper visit_Break uses for the loop condition).
            test=self._create_not_flag_node(break_continue_name),
            body=stmt_list[i + 1 :],
            orelse=[],
        )
        stmt_list[i + 1 :] = []
        stmt_list.append(if_stmt)
        return True

    def _add_stmt_before_cur_node(self, cur_node_index, stmt_node):
        """
        Insert ``stmt_node`` right before the ancestor at ``cur_node_index``
        in its parent's ``body`` or ``orelse`` list.

        Returns:
            True if the ancestor was found in one of the parent's lists and
            the statement was inserted, False otherwise.
        """
        cur_node = self.ancestor_nodes[cur_node_index]
        parent_node = self.ancestor_nodes[cur_node_index - 1]
        if hasattr(
            parent_node, "body"
        ) and self._add_stmt_into_list_before_node(
            parent_node.body, cur_node, stmt_node
        ):
            return True
        if hasattr(
            parent_node, "orelse"
        ) and self._add_stmt_into_list_before_node(
            parent_node.orelse, cur_node, stmt_node
        ):
            return True
        return False

    def _add_stmt_into_list_before_node(self, stmt_list, node, stmt_node):
        """
        Insert ``stmt_node`` into ``stmt_list`` immediately before ``node``.

        Returns:
            True if ``node`` was found and the insertion happened, False
            otherwise.
        """
        i = index_in_list(stmt_list, node)
        if i == -1:
            return False
        stmt_list.insert(i, stmt_node)
        return True

293 294 295 296 297 298 299 300 301 302

def _find_ancestor_loop_index(node, ancestor_nodes):
    for i in range(len(ancestor_nodes) - 1, -1, -1):
        if isinstance(ancestor_nodes[i], (gast.For, gast.While)):
            return i
    return -1


class BreakTransformOptimizer(BaseNodeVisitor):
    """
    Fold a leading ``if cond: break`` into the condition of its loop.

    When the first statement of a loop body is an ``if`` without ``else``
    whose body starts with ``break``, the ``if`` can be dropped and the
    negated ``If.test`` joined with the loop test instead.

    Currently supported pattern is:
    ```
        while cond1:            while cond1 and not cond2:
            if cond2:    --->       do_something()
                break
            do_something()
    ```

    See following example:

    >>> def foo(x):
    ...     i = paddle.to_tensor(1, dtype='int32')
    ...     while i < 10:
    ...         if x.mean() > 5:
    ...             break
    ...         x += i
    ...         i += 1
    ...     return x

    The generated code after applying optimization will be:
    ```
        def foo(x):
            i = paddle.to_tensor(1, dtype='int32')
            while i < 10 and not x.mean() > 5:
                x += i
                i += 1
            return x
    ```
    It can avoid wrapping all ops after `break` statement into `cond_op` that
    usually brings very heavy overhead.
    """

    def __init__(self, root):
        super().__init__()

        self.root = root

    def transform(self):
        """Walk the tree rooted at ``self.root`` and optimize it in place."""
        self.visit(self.root)

    def visit_Break(self, node):
        """Apply the optimization to ``node`` when the pattern matches."""
        loop_idx = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_idx != -1, "SyntaxError: 'break' outside loop"
        enclosing_loop = self.ancestor_nodes[loop_idx]

        # Anything that is not the supported pattern is left untouched.
        if not self._is_break_cond_pattern(node, enclosing_loop):
            return

        negated_test = self._join_with_while_cond(node, enclosing_loop)

        if isinstance(enclosing_loop, gast.While):
            enclosing_loop.test = gast.BoolOp(
                op=gast.And(), values=[enclosing_loop.test, negated_test]
            )
            return

        if isinstance(enclosing_loop, gast.For):
            # For loops are handed to ForToWhileTransformer together with
            # the extra condition.
            loop_parent = self.ancestor_nodes[loop_idx - 1]
            ForToWhileTransformer(
                loop_parent, enclosing_loop, negated_test
            ).transform()

    def _is_break_cond_pattern(self, break_node, loop_node):
        """
        Tell whether ``break_node`` sits in a ``loop -> if -> break`` chain
        whose ``If.test`` can be joined with the loop test.
        """
        ancestors = self.ancestor_nodes

        # while/for -> if -> break: the loop must be the grandparent.
        if len(ancestors) < 3:
            return False
        if ancestors[-3] != loop_node:
            return False

        assert ancestors[-1] == break_node
        enclosing_if = ancestors[-2]

        if not isinstance(enclosing_if, gast.If):
            return False

        # The `if` must start with the `break` and must not have an `else`.
        if enclosing_if.body[0] != break_node:
            return False
        if len(enclosing_if.orelse) != 0:
            return False

        # The `if` must be the very first statement of the loop body.
        return loop_node.body[0] == enclosing_if

    def _join_with_while_cond(self, break_node, loop_node):
        """
        Detach the leading ``if`` from ``loop_node.body`` and return the
        expression ``not <If.test>`` to be joined with the loop test.
        """
        guarding_if = self.ancestor_nodes[-2]

        # The gast.If holding the gast.Break is expected to be the first
        # statement of the loop; it is removed from the body here.
        assert loop_node.body[0] == guarding_if
        del loop_node.body[0]

        return gast.UnaryOp(op=gast.Not(), operand=guarding_if.test)