#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import gast

from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node

# Names exported by this module when it is star-imported.
__all__ = ['LoopTransformer', 'NameVisitor']

# Name prefixes passed to `unique_name.generate` when creating the condition
# and body functions that replace a python `while` statement.
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'

# Name prefixes passed to `unique_name.generate` when creating the condition
# and body functions that replace a python `for` statement.
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'

# Name prefix of the variables that replace dotted names (e.g. `foo.bar`)
# inside the generated condition/body functions.
GENERATE_VARIABLE_PREFIX = 'generate_variable'

def create_while_node(condition_name, body_name, loop_var_names):
    """
    Create the gast assignment node that calls `convert_while_loop`.

    The generated statement has the following form:

        (v1, v2, ...) = fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop(
            condition_name, body_name, [v1, v2, ...])

    Args:
        condition_name(str): Name of the generated loop condition function.
        body_name(str): Name of the generated loop body function.
        loop_var_names(Iterable[str]): Names of the loop variables. It is
            iterated exactly once.

    Returns:
        gast.Assign: The assignment node described above.
    """

    def _param_name(name_id):
        return gast.Name(
            id=name_id, ctx=gast.Param(), annotation=None, type_comment=None)

    # NOTE: the very same Name nodes are used as elements of the list argument
    # and as elements of the assignment target tuple.
    assign_targets = [_param_name(var_name) for var_name in loop_var_names]

    while_args = [
        _param_name(condition_name),
        _param_name(body_name),
        gast.List(
            elts=assign_targets, ctx=gast.Param()),
    ]

    # gast.parse returns Module(body=[expr(value=...)])
    while_func_id = gast.parse(
        'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop'
    ).body[0].value
    while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])
    assign_node = gast.Assign(
        targets=[gast.Tuple(
            elts=assign_targets, ctx=gast.Store())],
        value=while_node)
    return assign_node


class LogicalOpTransformer(gast.NodeTransformer):
    """
    Transform python boolean op into Paddle logical op.

    `not x` is replaced by a call of `convert_logical_not(x)` and
    `x and y` / `x or y` are replaced by calls of
    `convert_logical_and(x=x, y=y)` / `convert_logical_or(x=x, y=y)`, all of
    them referenced through
    `fluid.dygraph.dygraph_to_static.convert_operators`.
    """

    def __init__(self, node):
        # Root of the expression that will be transformed.
        self.root = node

    def transform(self):
        """
        Visit the root node and return the transformed node.
        """
        return self.visit(self.root)

    def visit_UnaryOp(self, node):
        """
        Replace `not operand` by a `convert_logical_not` call.

        Other unary operators are returned unchanged (their children are
        still visited).
        """
        self.generic_visit(node)
        if isinstance(node.op, gast.Not):
            arg = ast_to_source_code(node.operand)
            new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_not({})".format(
                arg)
            # gast.parse returns Module(body=[expr(value=...)])
            new_node = gast.parse(new_node_str).body[0].value
            return new_node
        return node

    def visit_BoolOp(self, node):
        """
        Replace `a and b ...` / `a or b ...` by `convert_logical_*` calls.

        Raises:
            TypeError: If the operator is neither `and` nor `or`.
        """
        self.generic_visit(node)
        if isinstance(node.op, gast.And):
            new_node = self._create_bool_op_node(node.values, 'and')
        elif isinstance(node.op, gast.Or):
            new_node = self._create_bool_op_node(node.values, 'or')
        else:
            raise TypeError(
                "Only supports and/or syntax in control flow if statement.")
        return new_node

    def _create_bool_op_node(self, nodes, api_type):
        """
        Create the `convert_logical_{api_type}` call node for operands `nodes`.

        Args:
            nodes(list): Operand nodes of the BoolOp, at least two of them.
            api_type(str): 'and' or 'or'.

        Returns:
            The expression node parsed from the generated call source code.
        """
        assert len(
            nodes
        ) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
            len(nodes))
        if len(nodes) > 2:
            # Creates logic_and/logic_or node recursively: the first two
            # operands form the left argument and all remaining operands form
            # the right argument.
            pre_logic_node = self._create_bool_op_node(nodes[:2], api_type)
            if len(nodes[2:]) == 1:
                post_logic_node = nodes[2]
            else:
                post_logic_node = self._create_bool_op_node(nodes[2:], api_type)
            nodes = [pre_logic_node] + [post_logic_node]

        args = [ast_to_source_code(child) for child in nodes]
        new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_{}(x={}, y={})".format(
            api_type, args[0], args[1])
        # gast.parse return Module(body=[expr(...)])
        new_node = gast.parse(new_node_str).body[0].value
        return new_node


class NameVisitor(gast.NodeVisitor):
    """
    Analyze name liveness for the loop transformer.

    The whole tree under `root_node` is visited in `__init__`. While visiting,
    the gast.Name/gast.Attribute nodes seen so far, seen inside each loop,
    written inside each loop and seen in each loop condition are recorded.
    `get_loop_var_names` uses these records to decide the loop variables of a
    gast.While/gast.For node.
    """

    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes that are currently being visited,
        # the innermost loop is the last element.
        self.current_loop = []

        # List of nodes that have scope of variables.
        self.nodes_with_scope = []

        # Names that are never recorded as variables. Names of visited
        # function definitions are added in `visit_FunctionDef`.
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        self.before_loop_body_vars = defaultdict(set)
        self.in_loop_vars = defaultdict(set)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        # True only while the condition part of a loop (test of gast.While,
        # target and iter of gast.For) is being visited.
        self.in_condition = False

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def is_control_flow_loop(self, node):
        """
        Return the result of `is_control_flow_to_transform` for `node`.
        """
        need_transform = is_control_flow_to_transform(
            node, self.static_analysis_visitor)
        return need_transform

    def get_loop_var_names(self, node):
        """
        Get the names of loop variables of a loop node.

        Args:
            node(gast.While|gast.For): The loop node.

        Returns:
            tuple(set, set): (loop_var_names, create_var_names).
                loop_var_names contains the names which should be passed into
                and returned from the loop. create_var_names is the subset of
                it whose names are not seen before the loop body but are read
                after the loop.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars = self.in_loop_vars[node]
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_body_vars = self._remove_target_vars_of_for(
            before_loop_body_vars, node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        # Every node seen in the whole tree which is neither seen before the
        # loop body nor inside the loop. Only the nodes that are read count.
        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node)
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)

        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)

        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        # If several nodes share a name, the type of the node iterated last
        # is kept.
        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type

        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop

                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (name not in condition_names) and (
                        name not in write_names
                ) and self._node_var_type_is_basic(name_to_type[name]):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it

                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)

        return loop_var_names, create_var_names

    def visit_Name(self, node):
        """
        Record a gast.Name node, unless it is the function of a call or its
        id is in `self.blacklist_names`.
        """
        if self._is_call_func_name_node(
                node) or node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
        # A condition variable belongs to the innermost loop only. The loop is
        # taken from `self.current_loop` explicitly instead of relying on the
        # loop variable of the `for` statement above.
        if self.in_condition and self.current_loop:
            self.condition_vars[self.current_loop[-1]].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """
        Visit a function definition and add its name to
        `self.blacklist_names`.
        """
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        # NOTE: this only happens if the function is nested in another
        # function definition.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        """
        Dispatch to `visit_<ClassName>` if it exists, else to `generic_visit`.
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        """
        Record a gast.Attribute node, unless it is the function of a call or
        its full name starts with "self.".
        """
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        #
        # def class_func(self):
        #     def while_loop_body(self.x, y) # `self.x` is illegal.
        #
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        """
        Visit a gast.For node and record the variables related to it.
        """
        self.current_loop.append(node)
        # Target and iter are treated as the condition part of a for loop.
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        # Snapshot taken after target and iter have been visited.
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        """
        Visit a gast.While node and record the variables related to it.
        """
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        # Snapshot taken after the test has been visited.
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        """
        Convert variable nodes into a set of name strings.

        Args:
            node_set(Iterable): gast.Name/gast.Attribute nodes.
            ctx_filter_set(set|None): If not None, only nodes whose ctx type
                is in this set are converted.
        """
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        """
        Return the id of a gast.Name or the full name of a gast.Attribute.
        None is returned implicitly for any other node type.
        """
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        """
        Return True if any type in `node_var_type` is BOOLEAN, INT, FLOAT or
        STRING of NodeVarType.
        """
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        for t in node_var_type:
            if t in basic_types:
                return True
        return False

    def _is_call_func_name_node(self, node):
        """
        Return True if `node` is the `func` of its parent gast.Call node.
        """
        parent_node = self._get_parent_node(node)
        if isinstance(parent_node, gast.Call) and parent_node.func == node:
            return True
        return False

    def _get_parent_node(self, node):
        """
        Return the parent node of `node` according to
        `self.node_to_wrapper_map`, or None if `node` has no wrapper.
        """
        wrapper_node = self.node_to_wrapper_map.get(node)
        if wrapper_node:
            parent_node = wrapper_node.parent.node
            return parent_node
        return None

    def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node):
        """
        Remove target vars of gast.For from before_loop_vars or after_loop_vars.
        :param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
        :param loop_node: Current loop node.
        :return: A new set without the removed nodes, the input set is not modified.
        """

        # Step 1: collect the nodes which are targets of a gast.For other than
        # loop_node.
        removed_vars = set()
        for name_node in before_or_after_loop_vars:
            if not isinstance(name_node, gast.Name):
                continue

            parent_node = self._get_parent_node(name_node)

            # NOTE: gast.For.target can be gast.Tuple.
            #  For example: `for i, j in enumerate(x)` has two target vars: i and j
            if isinstance(parent_node, gast.Tuple):
                parent_node = self._get_parent_node(parent_node)

            if isinstance(parent_node,
                          gast.For) and parent_node is not loop_node:
                target_node = parent_node.target

                if isinstance(target_node, gast.Tuple):
                    target_vars = target_node.elts
                else:
                    target_vars = [target_node]

                if name_node in target_vars:
                    removed_vars.add(name_node)

        removed_vars_name_strs = {var.id for var in removed_vars}

        # Step 2: also remove every other node with the same name, unless the
        # node is a condition var of loop_node.
        for var in before_or_after_loop_vars:
            if not isinstance(var, gast.Name):
                continue
            if var.id in removed_vars_name_strs and var not in self.condition_vars[
                    loop_node]:
                removed_vars.add(var)

        return before_or_after_loop_vars - removed_vars


class LoopTransformer(gast.NodeTransformer):
    """
    This class transforms python while/for statement into Static Graph Ast.

    Each transformed loop is replaced by a list of statements: variable
    creation/initialization statements, a condition function definition, a
    body function definition and the assignment created by
    `create_while_node`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # NOTE: NameVisitor analyzes the whole tree when it is constructed,
        # i.e. before any loop is transformed.
        self.name_visitor = NameVisitor(self.root)

    def transform(self):
        """
        Transform the loops under the root node in place. Returns None.
        """
        self.visit(self.root)

    def visit(self, node):
        """
        Visit children first, then replace the loops found directly in
        `node.body` and `node.orelse`.
        """
        self.generic_visit(node)
        # All parent nodes that may contain gast.While/gast.For
        if hasattr(node, 'body'):
            self.replace_stmt_list(node.body)
        if hasattr(node, 'orelse'):
            self.replace_stmt_list(node.orelse)
        return node

    def replace_stmt_list(self, body_list):
        """
        Replace every gast.While/gast.For in `body_list` by its transformed
        statements. `body_list` is modified in place; anything that is not a
        list is ignored.
        """
        if not isinstance(body_list, list):
            return

        i = 0
        while i < len(body_list):
            stmt = body_list[i]
            if isinstance(stmt, gast.While):
                new_stmts = self.get_while_stmt_nodes(stmt)
            elif isinstance(stmt, gast.For):
                new_stmts = self.get_for_stmt_nodes(stmt)
            else:
                i += 1
                continue
            body_list[i:i + 1] = new_stmts
            # Skip the statements that have just been inserted.
            i += len(new_stmts)

    def _create_loop_func_node(self, name_prefix, loop_var_names, body_stmts):
        """
        Create the definition node of a loop condition/body function.

        Args:
            name_prefix(str): Prefix passed to `unique_name.generate` to get
                the function name.
            loop_var_names(Iterable[str]): Names of the loop variables, they
                become the arguments of the function.
            body_stmts(list): Statements of the function body.

        Returns:
            gast.FunctionDef: The function definition node.
        """
        func_node = gast.FunctionDef(
            name=unique_name.generate(name_prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=body_stmts,
            decorator_list=[],
            returns=None,
            type_comment=None)
        # A dotted name (e.g. `foo.bar`) is not a legal argument name, so it
        # is renamed to a generated name inside the function.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node

    def get_for_stmt_nodes(self, node):
        """
        Get the statements which replace a gast.For node.

        Args:
            node(gast.For): The for loop node.

        Returns:
            list: `[node]` if the loop is not transformed, else the list of
                new statements.
        """
        # TODO: consider for - else in python

        # 1. check whether need to transform
        # NOTE: Current need transform cases:
        #   1). for x in range(VarBase[0]|VarBase.numpy()[0])
        #   2). for x in VarBase|VarBase.numpy()
        #   3). for i, x in enumerate(VarBase|VarBase.numpy())
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        # 2. get key statements for different cases
        # NOTE: three key statements:
        #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
        #   2). cond_stmt: node, condition node to judge whether continue loop
        #   3). body_stmts: list[node], updated loop body, sometimes we should change
        #       the original statement in body, not just append new statement
        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 3. get original loop vars
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
        # we need append new loop var & remove useless loop var
        #   1. for x in var -> x is no need
        #   2. for i, x in enumerate(var) -> x is no need
        if current_for_node_parser.is_for_iter(
        ) or current_for_node_parser.is_for_enumerate_iter():
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                # `discard` instead of `remove`: nothing has to be done if the
                # name is not a loop var.
                loop_var_names.discard(iter_var_name)

        # 4. prepare result statement list
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # 5. append init statements
        new_stmts.extend(init_stmts)
        # for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        # 6. create & append condition function node
        condition_func_node = self._create_loop_func_node(
            FOR_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # 7. create & append loop body function node
        # append return values for loop body
        body_stmts.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load())))
        body_func_node = self._create_loop_func_node(
            FOR_BODY_PREFIX, loop_var_names, body_stmts)
        new_stmts.append(body_func_node)

        # 8. create & append while loop node
        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name, loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts

    def get_while_stmt_nodes(self, node):
        """
        Get the statements which replace a gast.While node.

        Args:
            node(gast.While): The while loop node. A return statement is
                appended to `node.body` in place.

        Returns:
            list: The list of new statements.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # The python boolean ops in the test are converted before the test
        # becomes the return value of the condition function.
        logical_op_transformer = LogicalOpTransformer(node.test)
        cond_value_node = logical_op_transformer.transform()

        condition_func_node = self._create_loop_func_node(
            WHILE_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_value_node)])
        new_stmts.append(condition_func_node)

        # The body function returns all loop vars.
        new_body = node.body
        new_body.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load())))
        body_func_node = self._create_loop_func_node(
            WHILE_BODY_PREFIX, loop_var_names, new_body)
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name, loop_var_names)
        new_stmts.append(while_loop_node)
        return new_stmts